Compare commits

..

88 Commits

Author SHA1 Message Date
Jamison Lahman
bfdeb65bbb feat(ux): handle when chat session id cannot be found 2026-03-20 16:44:34 -07:00
Jamison Lahman
461350958a fix(fe): dim project name in sidebar color (#9519) 2026-03-20 17:47:49 +00:00
Raunak Bhagat
50dde0be1a chore: edit AGENTS.md and CLAUDE.md files (#9486) 2026-03-20 00:59:30 +00:00
acaprau
199e1df453 feat(opensearch): Add functions for keyword and semantic retrieval (#9479)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-20 00:48:01 +00:00
Justin Tahara
996b674840 feat(backend): Adding procps (#9509) 2026-03-19 23:26:36 +00:00
Justin Tahara
5413723ccc feat(ods): Rerun run-ci workflow (#9501) 2026-03-19 22:11:59 +00:00
Evan Lohn
9660056a51 fix: drive rate limit retry (#9498) 2026-03-19 21:32:08 +00:00
Fizza Mukhtar
3105177238 fix(llm): don't send tool_choice when no tools are provided (#9224) 2026-03-19 21:26:46 +00:00
Evan Lohn
24bb4bda8b feat: windows installer and install improvements (#9476) 2026-03-19 20:47:44 +00:00
Raunak Bhagat
9532af4ceb chore: move Hoverable story (#9495) 2026-03-19 20:40:27 +00:00
Jamison Lahman
0a913f6af5 fix(fe): fix memories immediately losing focus on click (#9493) 2026-03-19 20:15:34 +00:00
Justin Tahara
fe30c55199 fix(code interpreter): Caching files (#9484) 2026-03-19 19:32:37 +00:00
Jamison Lahman
2cf0a65dd3 chore(fe): reduce padding on elements at the bottom of modal headers (#9488) 2026-03-19 19:27:37 +00:00
Nikolas Garza
659416f363 feat(admin): groups page - list page and group cards (#9453) 2026-03-19 18:23:15 +00:00
Raunak Bhagat
40aecbc4b9 refactor(fe): move table to opal, update size API (#9438) 2026-03-19 17:23:41 +00:00
Jamison Lahman
710b39074f chore(fe): remove opal-button* class names (#9471) 2026-03-19 02:15:00 +00:00
acaprau
8fe2f67d38 chore(opensearch): Allow disabling match highlights via env var; default to disabled (#9436) 2026-03-19 00:43:17 +00:00
Justin Tahara
f00aaf9fc0 fix(agents): Agents are Private by Default (#9465) 2026-03-19 00:01:46 +00:00
Bo-Onyx
5b2426b002 chore(hooks): Define Hook Point in the backend (#9391) 2026-03-18 23:43:26 +00:00
Justin Tahara
ba6ab0245b fix(celery): add dedup guardrails to user file delete queue (#9454)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:38:52 +00:00
Justin Tahara
b64ebb57e1 fix(logging): extract LiteLLM error details in image summarization failures (#9458) 2026-03-18 23:29:04 +00:00
Justin Tahara
2fcfdbabde fix(celery): add task expiry to upload API send_task call (#9456)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:17:08 +00:00
Justin Tahara
ea1a2749c1 fix(image): add diagnostic logging to vision model selection (#9460) 2026-03-18 22:06:56 +00:00
Justin Tahara
73c4e22588 fix(image): stop dumping base64 image data into error logs (#9457) 2026-03-18 21:43:55 +00:00
Jamison Lahman
fceaac6e13 fix(fe): make indexing attempt error rows click to show trace (#9463)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 21:38:53 +00:00
Jamison Lahman
e8bf45cfd2 feat(fe): "Full Exception Trace" modal uses CodePreview rendering (#9464)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-18 21:04:55 +00:00
Bo-Onyx
13ff648fcd chore(hooks): Add Celery task to remove hook running records older than 30 days (#9433) 2026-03-18 21:03:01 +00:00
Jamison Lahman
ae8268afb1 fix(fe): truncate connector names in table (#9459) 2026-03-18 20:59:49 +00:00
acaprau
b338bd9e97 feat(opensearch): Can override number of shards and replicas via env var (#9431) 2026-03-18 20:16:05 +00:00
acaprau
0dcc90a042 fix(opensearch): Exclude retrieving vectors during hybrid and random search (#9430) 2026-03-18 20:13:12 +00:00
Jamison Lahman
0f6a6693d3 fix(fe): truncate project name in sidebar button (#9462) 2026-03-18 20:06:09 +00:00
Jamison Lahman
e32cc450b2 fix(fe): update connector indexing error modal (#9426)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 11:57:28 -07:00
Jamison Lahman
732fb71edf chore(tests): unit tests for pdf processing (#9452)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 18:31:37 +00:00
dependabot[bot]
ca3320c0e0 chore(deps): bump pypdf from 6.8.0 to 6.9.1 (#9450)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-18 17:52:50 +00:00
Jamison Lahman
d7c554aca7 chore(ruff): fix and enable S324 (#9451) 2026-03-18 17:26:29 +00:00
dependabot[bot]
69e5c19695 chore(deps): bump next from 16.1.5 to 16.1.7 in /examples/widget (#9425)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-18 09:25:27 -07:00
Nikolas Garza
b4ce1c7a97 chore: bump next to 16.1.7 (#9423) 2026-03-18 09:22:40 -07:00
Jamison Lahman
cd64a91154 fix(fe): display name on attachment file card hover (#9446) 2026-03-18 16:13:21 +00:00
Danelegend
c282cdc096 fix(file upload): Allow zip file upload via query param (#9432) 2026-03-18 07:32:07 +00:00
Jamison Lahman
b1de1c59b6 chore(playwright): projects screenshot is main container only (#9440) 2026-03-18 05:35:30 +00:00
acaprau
64d484039f chore(opensearch): Disable test_update_single_can_clear_user_projects_and_personas (#9434) 2026-03-18 00:40:29 +00:00
Jamison Lahman
0530095b71 fix(fe): replace users table buttons with LineItems (#9435) 2026-03-17 23:45:15 +00:00
acaprau
23280d5b91 fix(opensearch): Fix env var mismatch issue with configuring subquery results; set default to 100 (#9428) 2026-03-17 16:01:45 -07:00
Bo-Onyx
229442679c chore(hooks): Add db CRUD (#9411) 2026-03-17 22:36:50 +00:00
Jamison Lahman
95a192fb0f chore(devtools): upgrade ods: 0.7.0->0.7.1 (#9429) 2026-03-17 15:18:21 -07:00
Yuhong Sun
6bd96ec906 chore: Scripts for search quality eval (#9421) 2026-03-17 14:53:32 -07:00
Jamison Lahman
a1ec88269f chore(docker): configurable api_server resource limits (#9424)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-17 14:52:21 -07:00
Raunak Bhagat
b929518c34 refactor: update size variant and dimension names in opal components (#9416)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-03-17 21:43:19 +00:00
Justin Tahara
479220e774 chore(opensearch): Make Password Default Empty (#9415) 2026-03-17 21:41:08 +00:00
dependabot[bot]
d3e0acf905 chore(deps): bump pyasn1 from 0.6.2 to 0.6.3 (#9417)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-17 21:28:07 +00:00
acaprau
cbd1a344f2 feat(opensearch): Configure hybrid search subquery groups and pipelines via env var (#9407) 2026-03-17 21:25:53 +00:00
Jamison Lahman
e79264b69b chore(fe): prefer INTERNAL_URL (#9419) 2026-03-17 14:17:05 -07:00
Justin Tahara
1e0a8e9a0e fix(llm): surface masked OpenAI quota failures (#9308) 2026-03-17 21:11:43 +00:00
Jamison Lahman
b7841a513d chore(devtools): ods backend scans for available ports (#9418) 2026-03-17 14:09:23 -07:00
dependabot[bot]
c779bf722d chore(deps): bump next from 16.1.5 to 16.1.7 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9420)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 14:03:51 -07:00
Jamison Lahman
a5aff0d199 chore(fe): rm dead stackTraceModalContent (#9412) 2026-03-17 20:42:29 +00:00
Wenxi
8ed170b070 fix: make ConnectorCredentialPair name required (#9408)
Co-authored-by: Ciaran Sweet <ciaran@developmentseed.org>
2026-03-17 18:54:34 +00:00
Raunak Bhagat
c890cd4767 feat(llm-config): replace AdvancedOptions with unified ModelsAccessField (#9270)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 18:19:52 +00:00
Nikolas Garza
2b2df18463 fix(vespa): use weightedSet for ACL filters to prevent query failures (#9403) 2026-03-17 17:21:13 +00:00
Bo-Onyx
11cfc92f15 chore(hook): DB changes (#9337) 2026-03-17 01:04:06 +00:00
Jamison Lahman
c7da99cfd7 chore(playwright): make project name human-readable (#9394) 2026-03-16 17:26:20 -07:00
Jamison Lahman
b384c77863 chore(fe): admin navigation always goes to LLM config page (#9395) 2026-03-16 17:15:50 -07:00
Raunak Bhagat
b0f31cd46b fix(search-ui): center pagination in SearchUI (#9396) 2026-03-16 23:59:17 +00:00
Jamison Lahman
323eb9bbba chore(fe): make sidebar scrollbar flush with edge (#9383)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-16 23:33:40 +00:00
Raunak Bhagat
708e310849 refactor: refreshed Pagination component (#9380) 2026-03-16 23:14:59 +00:00
Wenxi
c25509e212 chore: run identify from backend (#9392) 2026-03-16 23:12:10 +00:00
Nikolas Garza
6af0da41bd test(admin): add E2E Playwright tests for Users page (#9266) 2026-03-16 21:41:24 +00:00
Evan Lohn
b94da25d7c chore: update install script (#9068) 2026-03-16 21:29:56 +00:00
Jamison Lahman
7d443c1b53 chore(ws): ignore port when determining origin in dev (#9382) 2026-03-16 21:24:34 +00:00
Justin Tahara
d6b7b3c68f fix(celery): Limiting connector_hierarchy_fetching jobs (#9381) 2026-03-16 21:14:04 +00:00
Jamison Lahman
f5073d331e chore(tests): fix flaky test_run_with_timeout_raises_on_timeout (#9377) 2026-03-16 19:02:58 +00:00
dependabot[bot]
64c9f6a0d5 chore(deps): bump docker/metadata-action from 5.10.0 to 6.0.0 (#9374)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:57:00 -07:00
dependabot[bot]
f5a494f790 chore(deps): bump actions/upload-artifact from 6.0.0 to 7.0.0 (#9375)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:56:45 -07:00
dependabot[bot]
8598e9f25d chore(deps): bump actions/checkout from 6.0.1 to 6.0.2 (#9373)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:56:26 -07:00
Justin Tahara
3ef8aecc54 test(ui): Add visual regression test for project files with long filenames (#9062) 2026-03-16 18:41:06 +00:00
Wenxi
eb311c7550 fix: use uuid as ph unique id from BE (#9371) 2026-03-16 18:06:34 +00:00
Jamison Lahman
13284d9def chore(voice): support non-default FE ports for IS_DEV (#9356) 2026-03-16 11:03:56 -07:00
Bo-Onyx
aaa99fcb60 chore(hook): Add feature control (#9320) 2026-03-16 17:48:53 +00:00
dependabot[bot]
5f628da4e8 chore(deps): bump authlib from 1.6.7 to 1.6.9 (#9370)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-16 17:21:05 +00:00
Jamison Lahman
e40f80cfe1 chore(posthog): allow no-op client in DEV_MODE (#9357) 2026-03-16 16:55:00 +00:00
Nikolas Garza
ca6ba2cca9 fix(admin): users page UI/UX polish (#9366) 2026-03-16 15:27:03 +00:00
Nikolas Garza
98ef5006ff feat(ci): add Slack @-mention support to slack-notify action (#9359) 2026-03-16 15:26:32 +00:00
Nikolas Garza
dfd168cde9 fix(fe): bump flatted to patch CVE-2026-32141 (#9350) 2026-03-14 05:46:04 +00:00
Raunak Bhagat
6c7ae243d0 feat: refresh admin sidebar with new sections, search, and disabled EE tabs (#9344) 2026-03-14 04:09:16 +00:00
Raunak Bhagat
c4a2ff2593 feat: add progress-bars opal icon (#9349) 2026-03-14 02:18:41 +00:00
Danelegend
4b74a6dc76 fix(litellm): filter embedding models (#9347) 2026-03-14 01:40:06 +00:00
dependabot[bot]
eea5f5b380 chore(deps): bump pyjwt from 2.11.0 to 2.12.0 (#9341)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-13 21:57:49 +00:00
Raunak Bhagat
ae428ba684 feat: add curate and user variant opal icons (#9343) 2026-03-13 21:51:02 +00:00
365 changed files with 17665 additions and 15422 deletions

View File

@@ -10,6 +10,9 @@ inputs:
failed-jobs:
description: "Deprecated alias for details"
required: false
mention:
description: "GitHub username to resolve to a Slack @-mention. Replaces {mention} in details."
required: false
title:
description: "Title for the notification"
required: false
@@ -26,6 +29,7 @@ runs:
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
DETAILS: ${{ inputs.details }}
FAILED_JOBS: ${{ inputs.failed-jobs }}
MENTION_USER: ${{ inputs.mention }}
TITLE: ${{ inputs.title }}
REF_NAME: ${{ inputs.ref-name }}
REPO: ${{ github.repository }}
@@ -52,6 +56,27 @@ runs:
DETAILS="$FAILED_JOBS"
fi
# Resolve {mention} placeholder if a GitHub username was provided.
# Looks up the username in user-mappings.json (co-located with this action)
# and replaces {mention} with <@SLACK_ID> for a Slack @-mention.
# Falls back to the plain GitHub username if not found in the mapping.
if [ -n "$MENTION_USER" ]; then
MAPPINGS_FILE="${GITHUB_ACTION_PATH}/user-mappings.json"
slack_id="$(jq -r --arg gh "$MENTION_USER" 'to_entries[] | select(.value | ascii_downcase == ($gh | ascii_downcase)) | .key' "$MAPPINGS_FILE" 2>/dev/null | head -1)"
if [ -n "$slack_id" ]; then
mention_text="<@${slack_id}>"
else
mention_text="${MENTION_USER}"
fi
DETAILS="${DETAILS//\{mention\}/$mention_text}"
TITLE="${TITLE//\{mention\}/}"
else
DETAILS="${DETAILS//\{mention\}/}"
TITLE="${TITLE//\{mention\}/}"
fi
normalize_multiline() {
printf '%s' "$1" | awk 'BEGIN { ORS=""; first=1 } { if (!first) printf "\\n"; printf "%s", $0; first=0 }'
}

View File

@@ -0,0 +1,18 @@
{
"U05SAGZPEA1": "yuhongsun96",
"U05SAH6UGUD": "Weves",
"U07PWEQB7A5": "evan-onyx",
"U07V1SM68KF": "joachim-danswer",
"U08JZ9N3QNN": "raunakab",
"U08L24NCLJE": "Subash-Mohan",
"U090B9M07B2": "wenxi-onyx",
"U094RASDP0Q": "duo-onyx",
"U096L8ZQ85B": "justin-tahara",
"U09AHV8UBQX": "jessicasingh7",
"U09KAL5T3C2": "nmgarza5",
"U09KPGVQ70R": "acaprau",
"U09QR8KTSJH": "rohoswagger",
"U09RB4NTXA4": "jmelahman",
"U0A6K9VCY6A": "Danelegend",
"U0AGC4KH71A": "Bo-Onyx"
}

View File

@@ -455,7 +455,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -529,7 +529,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -607,7 +607,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -668,7 +668,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -750,7 +750,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -836,7 +836,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -894,7 +894,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -967,7 +967,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -1044,7 +1044,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -1105,7 +1105,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
@@ -1178,7 +1178,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
@@ -1256,7 +1256,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
@@ -1317,7 +1317,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -1397,7 +1397,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -1480,7 +1480,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |

View File

@@ -207,7 +207,7 @@ jobs:
CHERRY_PICK_PR_URL: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_pr_url }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
details="*Cherry-pick PR opened successfully.*\\n• source PR: ${source_pr_url}"
details="*Cherry-pick PR opened successfully.*\\n• author: {mention}\\n• source PR: ${source_pr_url}"
if [ -n "${CHERRY_PICK_PR_URL}" ]; then
details="${details}\\n• cherry-pick PR: ${CHERRY_PICK_PR_URL}"
fi
@@ -221,6 +221,7 @@ jobs:
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
details: ${{ steps.success-summary.outputs.details }}
title: "✅ Automated Cherry-Pick PR Opened"
ref-name: ${{ github.event.pull_request.base.ref }}
@@ -275,20 +276,21 @@ jobs:
else
failed_job_label="cherry-pick-to-latest-release"
fi
failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
details="• author: {mention}\\n• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${MERGE_COMMIT_SHA}" ]; then
failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
fi
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
details="${details}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
echo "details=${details}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
details: ${{ steps.failure-summary.outputs.jobs }}
mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
details: ${{ steps.failure-summary.outputs.details }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.event.pull_request.base.ref }}

View File

@@ -105,7 +105,7 @@ jobs:
- name: Upload build artifacts
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
path: |

View File

@@ -174,7 +174,7 @@ jobs:
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/

View File

@@ -25,7 +25,7 @@ jobs:
outputs:
modules: ${{ steps.set-modules.outputs.modules }}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
persist-credentials: false
- id: set-modules
@@ -39,7 +39,7 @@ jobs:
matrix:
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]

View File

@@ -466,7 +466,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
@@ -587,7 +587,7 @@ jobs:
- name: Upload logs (onyx-lite)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-all-logs-onyx-lite
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
@@ -725,7 +725,7 @@ jobs:
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log

View File

@@ -44,7 +44,7 @@ jobs:
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage

View File

@@ -445,7 +445,7 @@ jobs:
run: |
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
- uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
if: always()
with:
# Includes test results and trace.zip files
@@ -454,7 +454,7 @@ jobs:
retention-days: 30
- name: Upload screenshots
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
if: always()
with:
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
@@ -534,7 +534,7 @@ jobs:
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
- name: Upload visual diff summary
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
if: always()
with:
name: screenshot-diff-summary-${{ matrix.project }}
@@ -543,7 +543,7 @@ jobs:
retention-days: 5
- name: Upload visual diff report artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
if: always()
with:
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
@@ -590,7 +590,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
@@ -674,7 +674,7 @@ jobs:
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
- uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
@@ -692,7 +692,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -122,7 +122,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -319,7 +319,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
path: |

View File

@@ -125,7 +125,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
@@ -195,7 +195,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
@@ -268,7 +268,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |

279
AGENTS.md
View File

@@ -167,284 +167,7 @@ web/
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
## Database & Migrations

View File

@@ -47,6 +47,8 @@ RUN apt-get update && \
gcc \
nano \
vim \
# Install procps so kubernetes exec sessions can use ps aux for debugging
procps \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \

View File

@@ -0,0 +1,103 @@
"""add_hook_and_hook_execution_log_tables
Revision ID: 689433b0d8de
Revises: 93a2e195e25c
Create Date: 2026-03-13 11:25:06.547474
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID as PGUUID
# revision identifiers, used by Alembic.
revision = "689433b0d8de"
down_revision = "93a2e195e25c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"hook",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column(
"hook_point",
sa.Enum("document_ingestion", "query_processing", native_enum=False),
nullable=False,
),
sa.Column("endpoint_url", sa.Text(), nullable=True),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("is_reachable", sa.Boolean(), nullable=True),
sa.Column(
"fail_strategy",
sa.Enum("hard", "soft", native_enum=False),
nullable=False,
),
sa.Column("timeout_seconds", sa.Float(), nullable=False),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("creator_id", PGUUID(as_uuid=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["creator_id"], ["user.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_hook_one_non_deleted_per_point",
"hook",
["hook_point"],
unique=True,
postgresql_where=sa.text("deleted = false"),
)
op.create_table(
"hook_execution_log",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("hook_id", sa.Integer(), nullable=False),
sa.Column(
"is_success",
sa.Boolean(),
nullable=False,
),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("status_code", sa.Integer(), nullable=True),
sa.Column("duration_ms", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["hook_id"], ["hook.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_hook_execution_log_hook_id", "hook_execution_log", ["hook_id"])
op.create_index(
"ix_hook_execution_log_created_at", "hook_execution_log", ["created_at"]
)
def downgrade() -> None:
op.drop_index("ix_hook_execution_log_created_at", table_name="hook_execution_log")
op.drop_index("ix_hook_execution_log_hook_id", table_name="hook_execution_log")
op.drop_table("hook_execution_log")
op.drop_index("ix_hook_one_non_deleted_per_point", table_name="hook")
op.drop_table("hook")

View File

@@ -118,9 +118,7 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY")
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"

View File

@@ -34,6 +34,9 @@ class PostHogFeatureFlagProvider(FeatureFlagProvider):
Returns:
True if the feature is enabled for the user, False otherwise.
"""
if not posthog:
return False
try:
posthog.set(
distinct_id=user_id,

View File

@@ -29,7 +29,6 @@ from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
@@ -59,7 +58,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
@@ -71,7 +69,9 @@ logger = setup_logger()
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
email: str,
referral_source: str | None = None,
request: Request | None = None,
) -> str:
"""
Get existing tenant ID for an email or create a new tenant if none exists.
@@ -693,12 +693,6 @@ async def assign_tenant_to_user(
try:
add_users_to_tenant([email], tenant_id)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=email,
event=MilestoneRecordType.TENANT_CREATED,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")

View File

@@ -9,6 +9,7 @@ from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -18,12 +19,19 @@ def posthog_on_error(error: Any, items: Any) -> None:
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
on_error=posthog_on_error,
)
posthog: Posthog | None = None
if POSTHOG_API_KEY:
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
on_error=posthog_on_error,
)
elif MULTI_TENANT:
logger.warning(
"POSTHOG_API_KEY is not set but MULTI_TENANT is enabled — "
"PostHog telemetry and feature flags will be disabled"
)
# For cross referencing between cloud and www Onyx sites
# NOTE: These clients are separate because they are separate posthog projects.
@@ -60,7 +68,7 @@ def capture_and_sync_with_alternate_posthog(
logger.error(f"Error capturing marketing posthog event: {e}")
try:
if cloud_user_id := props.get("onyx_cloud_user_id"):
if posthog and (cloud_user_id := props.get("onyx_cloud_user_id")):
cloud_props = props.copy()
cloud_props.pop("onyx_cloud_user_id", None)

View File

@@ -1,3 +1,5 @@
from typing import Any
from ee.onyx.utils.posthog_client import posthog
from onyx.utils.logger import setup_logger
@@ -5,12 +7,27 @@ logger = setup_logger()
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
distinct_id: str, event: str, properties: dict[str, Any] | None = None
) -> None:
"""Capture and send an event to PostHog, flushing immediately."""
if not posthog:
return
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
try:
posthog.capture(distinct_id, event, properties)
posthog.flush()
except Exception as e:
logger.error(f"Error capturing PostHog event: {e}")
def identify_user(distinct_id: str, properties: dict[str, Any] | None = None) -> None:
"""Create/update a PostHog person profile, flushing immediately."""
if not posthog:
return
try:
posthog.identify(distinct_id, properties)
posthog.flush()
except Exception as e:
logger.error(f"Error identifying PostHog user: {e}")

View File

@@ -19,6 +19,7 @@ from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
from urllib.parse import urlparse
import jwt
from email_validator import EmailNotValidError
@@ -134,6 +135,7 @@ from onyx.redis.redis_pool import retrieve_ws_token_data
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_identify
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -792,6 +794,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except Exception:
logger.exception("Error deleting anonymous user cookie")
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
mt_cloud_identify(
distinct_id=str(user.id),
properties={"email": user.email, "tenant_id": tenant_id},
)
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -810,12 +818,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
# Ensure a PostHog person profile exists for this user.
mt_cloud_identify(
distinct_id=str(user.id),
properties={"email": user.email, "tenant_id": tenant_id},
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
distinct_id=str(user.id),
event=MilestoneRecordType.USER_SIGNED_UP,
)
if user_count == 1:
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=str(user.id),
event=MilestoneRecordType.TENANT_CREATED,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1652,6 +1673,33 @@ async def _get_user_from_token_data(token_data: dict) -> User | None:
return user
_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1"})
def _is_same_origin(actual: str, expected: str) -> bool:
"""Compare two origins for the WebSocket CSWSH check.
Scheme and hostname must match exactly. Port must also match, except
when the hostname is a loopback address (localhost / 127.0.0.1 / ::1),
where port is ignored. On loopback, all ports belong to the same
operator, so port differences carry no security significance — the
CSWSH threat is remote origins, not local ones.
"""
a = urlparse(actual.rstrip("/"))
e = urlparse(expected.rstrip("/"))
if a.scheme != e.scheme or a.hostname != e.hostname:
return False
if a.hostname in _LOOPBACK_HOSTNAMES:
return True
actual_port = a.port or (443 if a.scheme == "https" else 80)
expected_port = e.port or (443 if e.scheme == "https" else 80)
return actual_port == expected_port
async def current_user_from_websocket(
websocket: WebSocket,
token: str = Query(..., description="WebSocket authentication token"),
@@ -1671,19 +1719,15 @@ async def current_user_from_websocket(
This applies the same auth checks as current_user() for HTTP endpoints.
"""
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
# Browsers always send Origin on WebSocket connections
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH).
# Browsers always send Origin on WebSocket connections.
origin = websocket.headers.get("origin")
expected_origin = WEB_DOMAIN.rstrip("/")
if not origin:
logger.warning("WS auth: missing Origin header")
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
actual_origin = origin.rstrip("/")
if actual_origin != expected_origin:
logger.warning(
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
)
if not _is_same_origin(origin, WEB_DOMAIN):
logger.warning(f"WS auth: origin mismatch. Expected {WEB_DOMAIN}, got {origin}")
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
# Validate WS token in Redis (single-use, deleted after retrieval)

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -9,6 +9,7 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if not MULTI_TENANT and HOOK_ENABLED:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -29,6 +29,8 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import ConnectorMissingException
from onyx.connectors.factory import identify_connector_class
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import HierarchyConnector
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
@@ -55,6 +57,26 @@ logger = setup_logger()
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60
def _connector_supports_hierarchy_fetching(
cc_pair: ConnectorCredentialPair,
) -> bool:
"""Return True only for connectors whose class implements HierarchyConnector."""
try:
connector_class = identify_connector_class(
cc_pair.connector.source,
)
except ConnectorMissingException as e:
task_logger.warning(
"Skipping hierarchy fetching enqueue for source=%s input_type=%s: %s",
cc_pair.connector.source,
cc_pair.connector.input_type,
str(e),
)
return False
return issubclass(connector_class, HierarchyConnector)
def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if hierarchy fetching is due for this connector.
@@ -186,7 +208,10 @@ def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair or not _is_hierarchy_fetching_due(cc_pair):
if not cc_pair or not _connector_supports_hierarchy_fetching(cc_pair):
continue
if not _is_hierarchy_fetching_due(cc_pair):
continue
task_id = _try_creating_hierarchy_fetching_task(

View File

@@ -0,0 +1,35 @@
from celery import shared_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import cleanup_old_execution_logs__no_commit
from onyx.utils.logger import setup_logger
logger = setup_logger()
_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30
@shared_task(
name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None: # noqa: ARG001
try:
with get_session_with_current_tenant() as db_session:
deleted: int = cleanup_old_execution_logs__no_commit(
db_session=db_session,
max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
)
db_session.commit()
if deleted:
logger.info(
f"Deleted {deleted} hook execution log(s) older than "
f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
)
except Exception:
logger.exception("Failed to clean up hook execution logs")
raise

View File

@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
@@ -91,6 +93,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a delete_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
@@ -546,7 +559,23 @@ def process_single_user_file(
ignore_result=True,
)
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with DELETING status and enqueue per-file tasks."""
"""Scan for user files with DELETING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still DELETING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
"""
task_logger.info("check_for_user_file_delete - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_delete_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
try:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
except Exception as e:
task_logger.exception(
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
)
return None
@@ -602,6 +663,9 @@ def delete_user_file_impl(
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the queued guard so the beat can re-enqueue if deletion fails
# and the file remains in DELETING status.
redis_client.delete(_user_file_delete_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,

View File

@@ -297,7 +297,9 @@ class PostgresCacheBackend(CacheBackend):
def _lock_id_for(self, name: str) -> int:
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
h = hashlib.md5(
f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
).digest()
return struct.unpack("q", h[:8])[0]

View File

@@ -36,9 +36,11 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import is_true_openai_model
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -72,6 +74,70 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
class EmptyLLMResponseError(RuntimeError):
"""Raised when the streamed LLM response completes without a usable answer."""
def __init__(
self,
*,
provider: str,
model: str,
tool_choice: ToolChoiceOptions,
client_error_msg: str,
error_code: str = "EMPTY_LLM_RESPONSE",
is_retryable: bool = True,
) -> None:
super().__init__(client_error_msg)
self.provider = provider
self.model = model
self.tool_choice = tool_choice
self.client_error_msg = client_error_msg
self.error_code = error_code
self.is_retryable = is_retryable
def _build_empty_llm_response_error(
llm: LLM,
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
) -> EmptyLLMResponseError:
provider = llm.config.model_provider
model = llm.config.model_name
# OpenAI quota exhaustion has reached us as a streamed "stop" with zero content.
# When the stream is completely empty and there is no reasoning/tool output, surface
# the likely account-level cause instead of a generic tool-calling error.
if (
not llm_step_result.reasoning
and provider == LlmProviderNames.OPENAI
and is_true_openai_model(provider, model)
):
return EmptyLLMResponseError(
provider=provider,
model=model,
tool_choice=tool_choice,
client_error_msg=(
"The selected OpenAI model returned an empty streamed response "
"before producing any tokens. This commonly happens when the API "
"key or project has no remaining quota or billing is not enabled. "
"Verify quota and billing for this key and try again."
),
error_code="BUDGET_EXCEEDED",
is_retryable=False,
)
return EmptyLLMResponseError(
provider=provider,
model=model,
tool_choice=tool_choice,
client_error_msg=(
"The selected model returned no final answer before the stream "
"completed. No text or tool calls were received from the upstream "
"provider."
),
)
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
@@ -613,7 +679,12 @@ def run_llm_loop(
)
citation_processor.update_citation_mapping(project_citation_mapping)
llm_step_result: LlmStepResult | None = None
llm_step_result = LlmStepResult(
reasoning=None,
answer=None,
tool_calls=None,
raw_answer=None,
)
# Pass the total budget to construct_message_history, which will handle token allocation
available_tokens = llm.config.max_input_tokens
@@ -1084,12 +1155,18 @@ def run_llm_loop(
# As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite
should_cite_documents = True
if not llm_step_result or not llm_step_result.answer:
if not llm_step_result.answer and not llm_step_result.tool_calls:
raise _build_empty_llm_response_error(
llm=llm,
llm_step_result=llm_step_result,
tool_choice=tool_choice,
)
if not llm_step_result.answer:
raise RuntimeError(
"The LLM did not return an answer. "
"Typically this is an issue with LLMs that do not support tool calling natively, "
"or the model serving API is not configured correctly. "
"This may also happen with models that are lower quality outputting invalid tool calls."
"The LLM did not return a final answer after tool execution. "
"Typically this indicates invalid tool-call output, a model/provider mismatch, "
"or serving API misconfiguration."
)
emitter.emit(

View File

@@ -1013,6 +1013,10 @@ def run_llm_step_pkt_generator(
accumulated_reasoning = ""
accumulated_answer = ""
accumulated_raw_answer = ""
stream_chunk_count = 0
actionable_chunk_count = 0
empty_chunk_count = 0
finish_reasons: set[str] = set()
xml_tool_call_content_filter = _XmlToolCallContentFilter()
processor_state: Any = None
@@ -1145,6 +1149,7 @@ def run_llm_step_pkt_generator(
user_identity=user_identity,
timeout_override=timeout_override,
):
stream_chunk_count += 1
if packet.usage:
usage = packet.usage
span_generation.span_data.usage = {
@@ -1154,16 +1159,21 @@ def run_llm_step_pkt_generator(
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
# Note: LLM cost tracking is now handled in multi_llm.py
finish_reason = packet.choice.finish_reason
if finish_reason:
finish_reasons.add(str(finish_reason))
delta = packet.choice.delta
# Weird behavior from some model providers, just log and ignore for now
if (
delta.content is None
not delta.content
and delta.reasoning_content is None
and delta.tool_calls is None
and not delta.tool_calls
):
empty_chunk_count += 1
logger.warning(
f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
"LLM packet is empty (no content, reasoning, or tool calls). "
f"finish_reason={finish_reason}. Skipping: {packet}"
)
continue
@@ -1172,6 +1182,8 @@ def run_llm_step_pkt_generator(
time.monotonic() - stream_start_time
)
first_action_recorded = True
if _delta_has_action(delta):
actionable_chunk_count += 1
if custom_token_processor:
# The custom token processor can modify the deltas for specific custom logic
@@ -1307,6 +1319,15 @@ def run_llm_step_pkt_generator(
else:
logger.debug("Tool calls: []")
if actionable_chunk_count == 0:
logger.warning(
"LLM stream completed with no actionable deltas. "
f"chunks={stream_chunk_count}, empty_chunks={empty_chunk_count}, "
f"finish_reasons={sorted(finish_reasons)}, "
f"provider={llm.config.model_provider}, model={llm.config.model_name}, "
f"tool_choice={tool_choice}, tools_sent={len(tool_definitions)}"
)
return (
LlmStepResult(
reasoning=accumulated_reasoning if accumulated_reasoning else None,

View File

@@ -29,6 +29,7 @@ from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
from onyx.chat.models import AnswerStream
from onyx.chat.models import ChatBasicResponse
@@ -490,13 +491,13 @@ def handle_stream_message_objects(
# Milestone tracking, most devs using the API don't need to understand this
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if not user.is_anonymous else tenant_id,
distinct_id=str(user.id) if not user.is_anonymous else tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if not user.is_anonymous else tenant_id,
distinct_id=str(user.id) if not user.is_anonymous else tenant_id,
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
@@ -925,9 +926,28 @@ def handle_stream_message_objects(
db_session.rollback()
return
except EmptyLLMResponseError as e:
stack_trace = traceback.format_exc()
logger.warning(
"LLM returned an empty response "
f"(provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})"
)
yield StreamingError(
error=e.client_error_msg,
stack_trace=stack_trace,
error_code=e.error_code,
is_retryable=e.is_retryable,
details={
"model": e.model,
"provider": e.provider,
"tool_choice": e.tool_choice.value,
},
)
db_session.rollback()
except Exception as e:
logger.exception(f"Failed to process chat message due to {e}")
error_msg = str(e)
stack_trace = traceback.format_exc()
if llm:
@@ -1046,10 +1066,46 @@ def llm_loop_completion_handle(
)
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
_CITATION_LINK_START_PATTERN = re.compile(r"\s*\[\[\d+\]\]\(")
return re.sub(pattern, "", answer)
def _find_markdown_link_end(text: str, destination_start: int) -> int | None:
depth = 0
i = destination_start
while i < len(text):
curr = text[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return i
depth -= 1
i += 1
return None
def remove_answer_citations(answer: str) -> str:
stripped_parts: list[str] = []
cursor = 0
while match := _CITATION_LINK_START_PATTERN.search(answer, cursor):
stripped_parts.append(answer[cursor : match.start()])
link_end = _find_markdown_link_end(answer, match.end())
if link_end is None:
stripped_parts.append(answer[match.start() :])
return "".join(stripped_parts)
cursor = link_end + 1
stripped_parts.append(answer[cursor:])
return "".join(stripped_parts)
@log_function_time()
@@ -1087,8 +1143,11 @@ def gather_stream(
raise ValueError("Message ID is required")
if answer is None:
# This should never be the case as these non-streamed flows do not have a stop-generation signal
raise RuntimeError("Answer was not generated")
if error_msg is not None:
answer = ""
else:
# This should never be the case as these non-streamed flows do not have a stop-generation signal
raise RuntimeError("Answer was not generated")
return ChatBasicResponse(
answer=answer,

View File

@@ -278,14 +278,17 @@ USING_AWS_MANAGED_OPENSEARCH = (
OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# Whether to disable match highlights for OpenSearch. Defaults to True for now
# as we investigate query performance.
OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = (
os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
@@ -318,8 +321,16 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
)
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
# If set, will override the default number of shards and replicas for the index.
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
else None
)
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
else None
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
@@ -1046,6 +1057,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
#####

View File

@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
# How long a queued user-file-delete task is valid before workers discard it.
# Mirrors the processing task expiry to prevent indefinite queue growth when
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before the delete beat stops enqueuing more delete tasks.
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
# Release notes
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
@@ -597,6 +608,9 @@ class OnyxCeleryTask:
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
# Hook execution log retention
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"

View File

@@ -157,9 +157,7 @@ def _execute_single_retrieval(
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
results = _execute_with_retry(retrieval_function(**request_kwargs))
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")

View File

@@ -511,7 +511,7 @@ def add_credential_to_connector(
user: User,
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
cc_pair_name: str,
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,

View File

@@ -304,3 +304,13 @@ class LLMModelFlowType(str, PyEnum):
CHAT = "chat"
VISION = "vision"
CONTEXTUAL_RAG = "contextual_rag"
class HookPoint(str, PyEnum):
DOCUMENT_INGESTION = "document_ingestion"
QUERY_PROCESSING = "query_processing"
class HookFailStrategy(str, PyEnum):
HARD = "hard" # exception propagates, pipeline aborts
SOFT = "soft" # log error, return original input, pipeline continues

233
backend/onyx/db/hook.py Normal file
View File

@@ -0,0 +1,233 @@
import datetime
from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.models import Hook
from onyx.db.models import HookExecutionLog
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ── Hook CRUD ────────────────────────────────────────────────────────────
def get_hook_by_id(
*,
db_session: Session,
hook_id: int,
include_deleted: bool = False,
include_creator: bool = False,
) -> Hook | None:
stmt = select(Hook).where(Hook.id == hook_id)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_non_deleted_hook_by_hook_point(
*,
db_session: Session,
hook_point: HookPoint,
include_creator: bool = False,
) -> Hook | None:
stmt = (
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
)
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_hooks(
*,
db_session: Session,
include_deleted: bool = False,
include_creator: bool = False,
) -> list[Hook]:
stmt = select(Hook)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
return list(db_session.scalars(stmt).all())
def create_hook__no_commit(
*,
db_session: Session,
name: str,
hook_point: HookPoint,
endpoint_url: str | None = None,
api_key: str | None = None,
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
At most one non-deleted hook per hook point is allowed. Raises
OnyxError(CONFLICT) if a hook already exists, including under concurrent
duplicate creates where the partial unique index fires an IntegrityError.
"""
existing = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if existing:
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
)
hook = Hook(
name=name,
hook_point=hook_point,
endpoint_url=endpoint_url,
api_key=api_key,
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,
# not the entire outer transaction.
savepoint = db_session.begin_nested()
try:
db_session.add(hook)
savepoint.commit()
except IntegrityError as exc:
savepoint.rollback()
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists.",
)
raise # re-raise unrelated integrity errors (FK violations, etc.)
return hook
def update_hook__no_commit(
*,
db_session: Session,
hook_id: int,
name: str | None = None,
endpoint_url: str | None | UnsetType = UNSET,
api_key: str | None | UnsetType = UNSET,
fail_strategy: HookFailStrategy | None = None,
timeout_seconds: float | None = None,
is_active: bool | None = None,
is_reachable: bool | None = None,
include_creator: bool = False,
) -> Hook:
"""Update hook fields.
Sentinel conventions:
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
"""
hook = get_hook_by_id(
db_session=db_session, hook_id=hook_id, include_creator=include_creator
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
if name is not None:
hook.name = name
if not isinstance(endpoint_url, UnsetType):
hook.endpoint_url = endpoint_url
if not isinstance(api_key, UnsetType):
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
if fail_strategy is not None:
hook.fail_strategy = fail_strategy
if timeout_seconds is not None:
hook.timeout_seconds = timeout_seconds
if is_active is not None:
hook.is_active = is_active
if is_reachable is not None:
hook.is_reachable = is_reachable
db_session.flush()
return hook
def delete_hook__no_commit(
*,
db_session: Session,
hook_id: int,
) -> None:
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
hook.deleted = True
hook.is_active = False
db_session.flush()
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
def create_hook_execution_log__no_commit(
*,
db_session: Session,
hook_id: int,
is_success: bool,
error_message: str | None = None,
status_code: int | None = None,
duration_ms: int | None = None,
) -> HookExecutionLog:
log = HookExecutionLog(
hook_id=hook_id,
is_success=is_success,
error_message=error_message,
status_code=status_code,
duration_ms=duration_ms,
)
db_session.add(log)
db_session.flush()
return log
def get_hook_execution_logs(
*,
db_session: Session,
hook_id: int,
limit: int,
) -> list[HookExecutionLog]:
stmt = (
select(HookExecutionLog)
.where(HookExecutionLog.hook_id == hook_id)
.order_by(HookExecutionLog.created_at.desc())
.limit(limit)
)
return list(db_session.scalars(stmt).all())
def cleanup_old_execution_logs__no_commit(
*,
db_session: Session,
max_age_days: int,
) -> int:
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=max_age_days
)
result: CursorResult = db_session.execute( # type: ignore[assignment]
delete(HookExecutionLog)
.where(HookExecutionLog.created_at < cutoff)
.execution_options(synchronize_session=False)
)
return result.rowcount

View File

@@ -64,6 +64,8 @@ from onyx.db.enums import (
BuildSessionStatus,
EmbeddingPrecision,
HierarchyNodeType,
HookFailStrategy,
HookPoint,
IndexingMode,
OpenSearchDocumentMigrationStatus,
OpenSearchTenantMigrationStatus,
@@ -5178,3 +5180,90 @@ class CacheStore(Base):
expires_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class Hook(Base):
"""Pairs a HookPoint with a customer-provided API endpoint.
At most one non-deleted Hook per HookPoint is allowed, enforced by a
partial unique index on (hook_point) where deleted=false.
"""
__tablename__ = "hook"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, nullable=False)
hook_point: Mapped[HookPoint] = mapped_column(
Enum(HookPoint, native_enum=False), nullable=False
)
endpoint_url: Mapped[str | None] = mapped_column(Text, nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
is_reachable: Mapped[bool | None] = mapped_column(
Boolean, nullable=True, default=None
) # null = never validated, true = last check passed, false = last check failed
fail_strategy: Mapped[HookFailStrategy] = mapped_column(
Enum(HookFailStrategy, native_enum=False),
nullable=False,
default=HookFailStrategy.HARD,
)
timeout_seconds: Mapped[float] = mapped_column(Float, nullable=False, default=30.0)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
creator_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
creator: Mapped["User | None"] = relationship("User", foreign_keys=[creator_id])
execution_logs: Mapped[list["HookExecutionLog"]] = relationship(
"HookExecutionLog", back_populates="hook", cascade="all, delete-orphan"
)
__table_args__ = (
Index(
"ix_hook_one_non_deleted_per_point",
"hook_point",
unique=True,
postgresql_where=(deleted == False), # noqa: E712
),
)
class HookExecutionLog(Base):
"""Records hook executions for health monitoring and debugging.
Currently only failures are logged; the is_success column exists so
success logging can be added later without a schema change.
Retention: rows older than 30 days are deleted by a nightly Celery task.
"""
__tablename__ = "hook_execution_log"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
hook_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("hook.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
is_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
status_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
)
hook: Mapped["Hook"] = relationship("Hook", back_populates="execution_logs")

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from contextlib import AbstractContextManager
@@ -18,6 +19,7 @@ from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
from onyx.utils.logger import setup_logger
@@ -56,8 +58,8 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
# Score explanation from OpenSearch when "explain": true is set in the
# query. Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
@@ -833,9 +835,13 @@ class OpenSearchIndexClient(OpenSearchClient):
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
) -> list[SearchHit[DocumentChunk]]:
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
"""Searches the index.
NOTE: Does not return vector fields. In order to take advantage of
performance benefits, the search body should exclude the schema's vector
fields.
TODO(andrei): Ideally we could check that every field in the body is
present in the index, to avoid a class of runtime bugs that could easily
be caught during development. Or change the function signature to accept
@@ -883,7 +889,7 @@ class OpenSearchIndexClient(OpenSearchClient):
raise_on_timeout=True,
)
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
for hit in hits:
document_chunk_source: dict[str, Any] | None = hit.get("_source")
if not document_chunk_source:
@@ -893,8 +899,10 @@ class OpenSearchIndexClient(OpenSearchClient):
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
search_hit = SearchHit[DocumentChunkWithoutVectors](
document_chunk=DocumentChunkWithoutVectors.model_validate(
document_chunk_source
),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
@@ -1055,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
f"Body: {get_new_body_without_vectors(body)}\n"
f"Search pipeline ID: {search_pipeline_id}\n"
f"Phase took: {phase_took}\n"
f"Profile: {profile}\n"
f"Profile: {json.dumps(profile, indent=2)}\n"
)
if timed_out:
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."

View File

@@ -1,12 +1,23 @@
# Default value for the maximum number of tokens a chunk can hold, if none is
# specified when creating an index.
from onyx.configs.app_configs import (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
import os
from enum import Enum
DEFAULT_MAX_CHUNK_SIZE = 512
# By default OpenSearch will only return a maximum of this many results in a
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
# Size of the dynamic list used to consider elements during kNN graph creation.
# Higher values improve search quality but increase indexing time. Values
# typically range between 100 - 512.
@@ -26,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
# we have a much higher chance of all 10 of the final desired docs showing up
# and getting scored. In worse situations, the final 10 docs don't even show up
# as the final 10 (worse than just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
else 750
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
# poor search performance.
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
)
# Number of vectors to examine to decide the top k neighbors for the HNSW
@@ -39,23 +50,43 @@ DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
# larger than k, you can provide the size parameter to limit the final number of
# results to k." from
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
# Since the titles are included in the contents, the embedding matches are
# heavily downweighted as they act as a boost rather than an independent scoring
# component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title
# keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
# NOTE: It is critical that the order of these weights matches the order of the
# sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
]
class HybridSearchSubqueryConfiguration(Enum):
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
# Current default.
CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 2
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0
# Will raise and block application start if HYBRID_SEARCH_SUBQUERY_CONFIGURATION
# is set but not a valid value. If not set, defaults to
# CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD.
HYBRID_SEARCH_SUBQUERY_CONFIGURATION: HybridSearchSubqueryConfiguration = (
HybridSearchSubqueryConfiguration(
int(os.environ["HYBRID_SEARCH_SUBQUERY_CONFIGURATION"])
)
if os.environ.get("HYBRID_SEARCH_SUBQUERY_CONFIGURATION", None) is not None
else HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
)
class HybridSearchNormalizationPipeline(Enum):
# Current default.
MIN_MAX = 1
# NOTE: Using z-score normalization is better for hybrid search from a
# theoretical standpoint. Empirically on a small dataset of up to 10K docs,
# it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
ZSCORE = 2
# Will raise and block application start if HYBRID_SEARCH_NORMALIZATION_PIPELINE
# is set but not a valid value. If not set, defaults to MIN_MAX.
HYBRID_SEARCH_NORMALIZATION_PIPELINE: HybridSearchNormalizationPipeline = (
HybridSearchNormalizationPipeline(
int(os.environ["HYBRID_SEARCH_NORMALIZATION_PIPELINE"])
)
if os.environ.get("HYBRID_SEARCH_NORMALIZATION_PIPELINE", None) is not None
else HybridSearchNormalizationPipeline.MIN_MAX
)

View File

@@ -47,6 +47,7 @@ from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
@@ -55,16 +56,13 @@ from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
get_min_max_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_NAME,
get_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
from onyx.document_index.opensearch.search import (
ZSCORE_NORMALIZATION_PIPELINE_NAME,
get_zscore_normalization_pipeline_name_and_config,
)
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import Document
@@ -103,18 +101,24 @@ def set_cluster_state(client: OpenSearchClient) -> None:
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = (
get_min_max_normalization_pipeline_name_and_config()
)
zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = (
get_zscore_normalization_pipeline_name_and_config()
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=min_max_normalization_pipeline_name,
pipeline_body=min_max_normalization_pipeline_config,
)
client.create_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
pipeline_body=zscore_normalization_pipeline_config,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
chunk: DocumentChunkWithoutVectors,
score: float | None,
highlights: dict[str, list[str]],
) -> InferenceChunkUncleaned:
@@ -877,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
results: list[InferenceChunk] = []
for chunk_request in chunk_requests:
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
query_body = DocumentQuery.get_from_document_id_query(
document_id=chunk_request.document_id,
tenant_state=self._tenant_state,
@@ -940,17 +944,92 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
include_hidden=False,
)
# NOTE: Using z-score normalization here because it's better for hybrid
# search from a theoretical standpoint. Empirically on a small dataset
# of up to 10K docs, it's not very different. Likely more impactful at
# scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config()
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=normalization_pipeline_name,
)
# Good place for a breakpoint to inspect the search hits if you have
# "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_keyword_search_query(
query_text=query,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_embedding,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
@@ -977,7 +1056,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)

View File

@@ -11,6 +11,8 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.document_index.interfaces_new import TenantState
@@ -100,9 +102,9 @@ def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
return value
class DocumentChunk(BaseModel):
class DocumentChunkWithoutVectors(BaseModel):
"""
Represents a chunk of a document in the OpenSearch index.
Represents a chunk of a document in the OpenSearch index without vectors.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
@@ -124,9 +126,7 @@ class DocumentChunk(BaseModel):
# Either both should be None or both should be non-None.
title: str | None = None
title_vector: list[float] | None = None
content: str
content_vector: list[float]
source_type: str
# A list of key-value pairs separated by INDEX_SEPARATOR. See
@@ -176,19 +176,9 @@ class DocumentChunk(BaseModel):
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})."
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
@model_serializer(mode="wrap")
def serialize_model(
self, handler: SerializerFunctionWrapHandler
@@ -305,6 +295,35 @@ class DocumentChunk(BaseModel):
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
class DocumentChunk(DocumentChunkWithoutVectors):
"""Represents a chunk of a document in the OpenSearch index.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
"""
model_config = {"frozen": True}
title_vector: list[float] | None = None
content_vector: list[float]
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
class DocumentSchema:
"""
Represents the schema and indexing strategies of the OpenSearch index.
@@ -516,78 +535,35 @@ class DocumentSchema:
return schema
@staticmethod
def get_index_settings() -> dict[str, Any]:
"""
Standard settings for reasonable local index and search performance.
"""
return {
"index": {
"number_of_shards": 1,
"number_of_replicas": 1,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch.
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
zones.
- We use 3 shards to distribute load across all data nodes.
- We use 2 replicas to ensure each shard has a copy in each
availability zone. This is a hard requirement from AWS. The number
of data copies, including the primary (not a replica) copy, must be
divisible by the number of AZs.
"""
return {
"index": {
"number_of_shards": 3,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch in multi-tenant cloud.
324 shards very roughly targets a storage load of ~30Gb per shard, which
according to AWS OpenSearch documentation is within a good target range.
As documented above we need 2 replicas for a total of 3 copies of the
data because the cluster is configured with 3-AZ awareness.
"""
return {
"index": {
"number_of_shards": 324,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
"""
Returns the index settings based on the environment.
"""
if USING_AWS_MANAGED_OPENSEARCH:
# NOTE: The number of data copies, including the primary (not a
# replica) copy, must be divisible by the number of AZs.
if MULTI_TENANT:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
)
number_of_shards = 324
number_of_replicas = 2
else:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
)
number_of_shards = 3
number_of_replicas = 2
else:
return DocumentSchema.get_index_settings()
number_of_shards = 1
number_of_replicas = 1
if OPENSEARCH_INDEX_NUM_SHARDS is not None:
number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS
if OPENSEARCH_INDEX_NUM_REPLICAS is not None:
number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS
return {
"index": {
"number_of_shards": number_of_shards,
"number_of_replicas": number_of_replicas,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}

View File

@@ -7,16 +7,28 @@ from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
)
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.constants import (
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
)
from onyx.document_index.opensearch.constants import (
HYBRID_SEARCH_NORMALIZATION_PIPELINE,
)
from onyx.document_index.opensearch.constants import (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION,
)
from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline
from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
@@ -43,49 +55,113 @@ from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# TODO(andrei): Turn all magic dictionaries to pydantic models.
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using min-max",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "min_max"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
},
def _get_hybrid_search_normalization_weights() -> list[float]:
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
# Since the titles are included in the contents, the embedding matches
# are heavily downweighted as they act as a boost rather than an
# independent scoring component.
search_title_vector_weight = 0.1
search_content_vector_weight = 0.45
# Single keyword weight for both title and content (merged from former
# title keyword + content keyword).
search_keyword_weight = 0.45
# NOTE: It is critical that the order of these weights matches the order
# of the sub-queries in the hybrid search.
hybrid_search_normalization_weights = [
search_title_vector_weight,
search_content_vector_weight,
search_keyword_weight,
]
elif (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
search_content_vector_weight = 0.5
# Single keyword weight for both title and content (merged from former
# title keyword + content keyword).
search_keyword_weight = 0.5
# NOTE: It is critical that the order of these weights matches the order
# of the sub-queries in the hybrid search.
hybrid_search_normalization_weights = [
search_content_vector_weight,
search_keyword_weight,
]
else:
raise ValueError(
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}."
)
assert (
sum(hybrid_search_normalization_weights) == 1.0
), "Bug: Hybrid search normalization weights do not sum to 1.0."
return hybrid_search_normalization_weights
def get_min_max_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
min_max_normalization_pipeline_name = "normalization_pipeline_min_max"
min_max_normalization_pipeline_config: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using min-max",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "min_max"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {
"weights": _get_hybrid_search_normalization_weights()
},
},
}
}
}
],
}
],
}
return min_max_normalization_pipeline_name, min_max_normalization_pipeline_config
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using z-score",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "z_score"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
},
def get_zscore_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
zscore_normalization_pipeline_name = "normalization_pipeline_zscore"
zscore_normalization_pipeline_config: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using z-score",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "z_score"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {
"weights": _get_hybrid_search_normalization_weights()
},
},
}
}
}
],
}
],
}
return zscore_normalization_pipeline_name, zscore_normalization_pipeline_config
# By default OpenSearch will only return a maximum of this many results in a
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
def get_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
if (
HYBRID_SEARCH_NORMALIZATION_PIPELINE
is HybridSearchNormalizationPipeline.MIN_MAX
):
return get_min_max_normalization_pipeline_name_and_config()
elif (
HYBRID_SEARCH_NORMALIZATION_PIPELINE is HybridSearchNormalizationPipeline.ZSCORE
):
return get_zscore_normalization_pipeline_name_and_config()
else:
raise ValueError(
f"Bug: Unhandled hybrid search normalization pipeline: {HYBRID_SEARCH_NORMALIZATION_PIPELINE}."
)
class DocumentQuery:
@@ -160,9 +236,17 @@ class DocumentQuery:
# returning some number of results less than the index max allowed
# return size.
"size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
"_source": get_full_document,
# By default exclude retrieving the vector fields in order to save
# on retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
if not get_full_document:
# If we explicitly do not want the underlying document, we will only
# retrieve IDs.
final_get_ids_query["_source"] = False
if not OPENSEARCH_PROFILING_DISABLED:
final_get_ids_query["profile"] = True
@@ -257,7 +341,7 @@ class DocumentQuery:
# TODO(andrei, yuhong): We can tune this more dynamically based on
# num_hits.
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, vector_candidates=max_results_per_subquery
@@ -281,9 +365,6 @@ class DocumentQuery:
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
match_highlights_configuration = (
DocumentQuery._get_match_highlights_configuration()
)
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
hybrid_search_query: dict[str, Any] = {
@@ -310,16 +391,183 @@ class DocumentQuery:
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
# Explain is for scoring breakdowns.
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_hybrid_search_body["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@staticmethod
def get_keyword_search_query(
query_text: str,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final keyword search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the keyword search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final keyword search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
keyword_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
keyword_search_query = (
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text, search_filters=keyword_search_filters
)
)
final_keyword_search_query: dict[str, Any] = {
"query": keyword_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_keyword_search_query["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
if not OPENSEARCH_PROFILING_DISABLED:
final_keyword_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_keyword_search_query["explain"] = True
return final_keyword_search_query
@staticmethod
def get_semantic_search_query(
query_embedding: list[float],
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final semantic search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the semantic search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final semantic search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
semantic_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
semantic_search_query = (
DocumentQuery._get_content_vector_similarity_search_query(
query_embedding,
vector_candidates=num_hits,
search_filters=semantic_search_filters,
)
)
final_semantic_search_query: dict[str, Any] = {
"query": semantic_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_semantic_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_semantic_search_query["explain"] = True
return final_semantic_search_query
@staticmethod
def get_random_search_query(
tenant_state: TenantState,
@@ -371,6 +619,11 @@ class DocumentQuery:
},
"size": num_to_retrieve,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_random_search_query["profile"] = True
@@ -385,7 +638,7 @@ class DocumentQuery:
# search. This is higher than the number of results because the scoring
# is hybrid. For a detailed breakdown, see where the default value is
# set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -395,20 +648,18 @@ class DocumentQuery:
The return of this function is not sufficient to be directly supplied to
the OpenSearch client. See get_hybrid_search_query.
Matches:
- Title vector
- Content vector
- Keyword (title + content, match and phrase)
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
pipeline.
The exact subqueries executed depend on the
HYBRID_SEARCH_SUBQUERY_CONFIGURATION setting.
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
in a single hybrid query. Source:
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
NOTE: Each query is independent during the search phase, there is no
NOTE: Each query is independent during the search phase; there is no
backfilling of scores for missing query components. What this means is
that if a document was a good vector match but did not show up for
keyword, it gets a score of 0 for the keyword component of the hybrid
@@ -437,74 +688,133 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, content vector, keyword (title + content).
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 3. Keyword (title + content) match and phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
"operator": "or",
"boost": 1.0,
}
}
},
{
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 1.5,
}
}
},
]
}
},
]
# pipeline weights.
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
return [
DocumentQuery._get_title_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_content_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text
),
]
elif (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
return [
DocumentQuery._get_content_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text
),
]
else:
raise ValueError(
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}"
)
return hybrid_search_queries
@staticmethod
def _get_title_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> dict[str, Any]:
return {
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
}
@staticmethod
def _get_content_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
query = {
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
}
if search_filters is not None:
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
"bool": {"filter": search_filters}
}
return query
@staticmethod
def _get_title_content_combined_keyword_search_query(
query_text: str,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
query = {
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
"operator": "or",
"boost": 1.0,
}
}
},
{
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 1.5,
}
}
},
],
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
if search_filters is not None:
query["bool"]["filter"] = search_filters
return query
@staticmethod
def _get_search_filters(

View File

@@ -501,20 +501,31 @@ def query_vespa(
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to query Vespa"
logger.error(
f"{error_base}:\n"
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
+ (
f"\nResponse: {e.response.text}"
if isinstance(e, httpx.HTTPStatusError)
else ""
)
response_text = (
e.response.text if isinstance(e, httpx.HTTPStatusError) else None
)
raise httpx.HTTPError(error_base) from e
status_code = (
e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
)
yql_value = params.get("yql", "")
yql_length = len(str(yql_value))
# Log each detail on its own line so log collectors capture them
# as separate entries rather than truncating a single multiline msg
logger.error(
f"Failed to query Vespa | "
f"status={status_code} | "
f"yql_length={yql_length} | "
f"exception={str(e)}"
)
if response_text:
logger.error(f"Vespa error response: {response_text[:1000]}")
logger.error(f"Vespa request URL: {e.request.url}")
# Re-raise with diagnostics so callers see what actually went wrong
raise httpx.HTTPError(
f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})"
) from e
response_json: dict[str, Any] = response.json()

View File

@@ -43,6 +43,22 @@ def build_vespa_filters(
return ""
return f"({' or '.join(eq_elems)})"
def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
"""Build a Vespa weightedSet filter for large value lists.
Uses Vespa's native weightedSet() operator instead of OR-chained
'contains' clauses. This is critical for fields like
access_control_list where a single user may have tens of thousands
of ACL entries — OR clauses at that scale cause Vespa to reject
the query with HTTP 400."""
if not key or not vals:
return ""
filtered = [val for val in vals if val]
if not filtered:
return ""
items = ", ".join(f'"{val}":1' for val in filtered)
return f"weightedSet({key}, {{{items}}})"
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
"""For an integer field filter.
Returns a bare clause or ""."""
@@ -157,11 +173,16 @@ def build_vespa_filters(
if filters.tenant_id and MULTI_TENANT:
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
# ACL filters
# ACL filters — use weightedSet for efficient matching against the
# access_control_list weightedset<string> field. OR-chaining thousands
# of 'contains' clauses causes Vespa to reject the query (HTTP 400)
# for users with large numbers of external permission groups.
if filters.access_control_list is not None:
_append(
filter_parts,
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
_build_weighted_set_filter(
ACCESS_CONTROL_LIST, filters.access_control_list
),
)
# Source type filters

View File

@@ -35,6 +35,8 @@ class OnyxErrorCode(Enum):
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
ADMIN_ONLY = ("ADMIN_ONLY", 403)
EE_REQUIRED = ("EE_REQUIRED", 403)
SINGLE_TENANT_ONLY = ("SINGLE_TENANT_ONLY", 403)
ENV_VAR_GATED = ("ENV_VAR_GATED", 403)
# ------------------------------------------------------------------
# Validation / Bad Request (400)

View File

@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
try:
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
except UnsupportedImageFormatError:
magic_hex = image_data[:8].hex() if image_data else "empty"
logger.info(
"Skipping image summarization due to unsupported MIME type for %s",
"Skipping image summarization due to unsupported MIME type "
"for %s (magic_bytes=%s, size=%d bytes)",
context_name,
magic_hex,
len(image_data),
)
return None
@@ -134,9 +138,23 @@ def _summarize_image(
return summary
except Exception as e:
error_msg = f"Summarization failed. Messages: {messages}"
error_msg = error_msg[:1024]
raise ValueError(error_msg) from e
# Extract structured details from LiteLLM exceptions when available,
# rather than dumping the full messages payload (which contains base64
# image data and produces enormous, unreadable error logs).
str_e = str(e)
if len(str_e) > 512:
str_e = str_e[:512] + "... (truncated)"
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
status_code = getattr(e, "status_code", None)
llm_provider = getattr(e, "llm_provider", None)
model = getattr(e, "model", None)
if status_code is not None:
parts.append(f"status_code={status_code}")
if llm_provider is not None:
parts.append(f"llm_provider={llm_provider}")
if model is not None:
parts.append(f"model={model}")
raise ValueError(" | ".join(parts)) from e
def _encode_image_for_llm_prompt(image_data: bytes) -> str:

View File

View File

@@ -0,0 +1,26 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Use as: Depends(require_hook_enabled)
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.SINGLE_TENANT_ONLY,
"Hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.ENV_VAR_GATED,
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -0,0 +1,127 @@
from datetime import datetime
from enum import Enum
from typing import Annotated
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import SecretStr
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class HookCreateRequest(BaseModel):
name: str = Field(min_length=1)
hook_point: HookPoint
endpoint_url: str = Field(min_length=1)
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None, uses HookPointSpec default
@field_validator("name", "endpoint_url")
@classmethod
def no_whitespace_only(cls, v: str) -> str:
if not v.strip():
raise ValueError("cannot be whitespace-only.")
return v
class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = (
None # if None in model_fields_set, reset to spec default
)
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None in model_fields_set, reset to spec default
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
if not self.model_fields_set:
raise ValueError("At least one field must be provided for an update.")
if "name" in self.model_fields_set and not (self.name or "").strip():
raise ValueError("name cannot be cleared.")
if (
"endpoint_url" in self.model_fields_set
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
return self
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class HookPointMetaResponse(BaseModel):
hook_point: HookPoint
display_name: str
description: str
docs_url: str | None
input_schema: dict[str, Any]
output_schema: dict[str, Any]
default_timeout_seconds: float
default_fail_strategy: HookFailStrategy
fail_hard_description: str
class HookResponse(BaseModel):
id: int
name: str
hook_point: HookPoint
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateResponse(BaseModel):
success: bool
error_message: str | None = None
# ---------------------------------------------------------------------------
# Health models
# ---------------------------------------------------------------------------
class HookHealthStatus(str, Enum):
healthy = "healthy" # green — reachable, no failures in last 1h
degraded = "degraded" # yellow — reachable, failures in last 1h
unreachable = "unreachable" # red — is_reachable=false or null
class HookFailureRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime
class HookHealthResponse(BaseModel):
status: HookHealthStatus
recent_failures: list[HookFailureRecord] = Field(
default_factory=list,
description="Last 10 failures, newest first",
max_length=10,
)

View File

View File

@@ -0,0 +1,59 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
_REQUIRED_ATTRS = (
"hook_point",
"display_name",
"description",
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
)
class HookPointSpec(ABC):
"""Static metadata and contract for a pipeline hook point.
This is NOT a regular class meant for direct instantiation by callers.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
should ever create instances directly — use get_hook_point_spec() or
get_all_specs() from the registry instead.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
"""
hook_point: HookPoint
display_name: str
description: str
default_timeout_seconds: float
fail_hard_description: str
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
def __init_subclass__(cls, **kwargs: object) -> None:
super().__init_subclass__(**kwargs)
# Skip intermediate abstract subclasses — they may still be partially defined.
if getattr(cls, "__abstractmethods__", None):
return
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
@property
@abstractmethod
def input_schema(self) -> dict[str, Any]:
"""JSON schema describing the request payload sent to the customer's endpoint."""
@property
@abstractmethod
def output_schema(self) -> dict[str, Any]:
"""JSON schema describing the expected response from the customer's endpoint."""

View File

@@ -0,0 +1,29 @@
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
"""
hook_point = HookPoint.DOCUMENT_INGESTION
display_name = "Document Ingestion"
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define input schema
return {"type": "object", "properties": {}}
@property
def output_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define output schema
return {"type": "object", "properties": {}}

View File

@@ -0,0 +1,83 @@
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
Call site: inside handle_stream_message_objects() in
backend/onyx/chat/process_message.py, immediately after message_text is
assigned from the request and before create_new_chat_message() saves it.
This is the earliest possible point in the query pipeline:
- Raw query — unmodified, exactly as the user typed it
- No side effects yet — message has not been saved to DB
- User identity is available for user-specific logic
Supported use cases:
- Query rejection: block queries based on content or user context
- Query rewriting: normalize, expand, or modify the query
- PII removal: scrub sensitive data before the LLM sees it
- Access control: reject queries from certain users or groups
- Query auditing: log or track queries based on business rules
"""
hook_point = HookPoint.QUERY_PROCESSING
display_name = "Query Processing"
description = (
"Runs on every user query before it enters the pipeline. "
"Allows rewriting, filtering, or rejecting queries."
)
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
fail_hard_description = (
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The raw query string exactly as the user typed it.",
},
"user_email": {
"type": ["string", "null"],
"description": "Email of the user submitting the query, or null if unauthenticated.",
},
"chat_session_id": {
"type": "string",
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
},
},
"required": ["query", "user_email", "chat_session_id"],
"additionalProperties": False,
}
@property
def output_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": (
"The (optionally modified) query to use. "
"Set to null to reject the query."
),
},
"rejection_message": {
"type": ["string", "null"],
"description": (
"Message shown to the user when query is null. "
"Falls back to a generic message if not provided."
),
},
},
"required": ["query"],
}

View File

@@ -0,0 +1,45 @@
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
from onyx.hooks.points.query_processing import QueryProcessingSpec
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
_REGISTRY: dict[HookPoint, HookPointSpec] = {
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
}
def validate_registry() -> None:
"""Assert that every HookPoint enum value has a registered spec.
Call once at application startup (e.g. from the FastAPI lifespan hook).
Raises RuntimeError if any hook point is missing a spec.
"""
missing = set(HookPoint) - set(_REGISTRY)
if missing:
raise RuntimeError(
f"Hook point(s) have no registered spec: {missing}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
"""Returns the spec for a given hook point.
Raises ValueError if the hook point has no registered spec — this is a
programmer error; every HookPoint enum value must have a corresponding spec
in _REGISTRY.
"""
try:
return _REGISTRY[hook_point]
except KeyError:
raise ValueError(
f"No spec registered for hook point {hook_point!r}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_all_specs() -> list[HookPointSpec]:
"""Returns the specs for all registered hook points."""
return list(_REGISTRY.values())

View File

@@ -395,6 +395,12 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
llm = get_default_llm_with_vision()
if not llm:
if get_image_extraction_and_analysis_enabled():
logger.warning(
"Image analysis is enabled but no vision-capable LLM is "
"available — images will not be summarized. Configure a "
"vision model in the admin LLM settings."
)
# Even without LLM, we still convert to IndexingDocument with base Sections
return [
IndexingDocument(

View File

@@ -168,10 +168,23 @@ def get_default_llm_with_vision(
if model_supports_image_input(
default_model.name, default_model.llm_provider.provider
):
logger.info(
"Using default vision model: %s (provider=%s)",
default_model.name,
default_model.llm_provider.provider,
)
return create_vision_llm(
LLMProviderView.from_model(default_model.llm_provider),
default_model.name,
)
else:
logger.warning(
"Default vision model %s (provider=%s) does not support "
"image input — falling back to searching all providers",
default_model.name,
default_model.llm_provider.provider,
)
# Fall back to searching all providers
models = fetch_existing_models(
db_session=db_session,
@@ -179,6 +192,10 @@ def get_default_llm_with_vision(
)
if not models:
logger.warning(
"No LLM models with VISION or CHAT flow type found — "
"image summarization will be disabled"
)
return None
for model in models:
@@ -200,11 +217,25 @@ def get_default_llm_with_vision(
for model in sorted_models:
if model_supports_image_input(model.name, model.llm_provider.provider):
logger.info(
"Using fallback vision model: %s (provider=%s)",
model.name,
model.llm_provider.provider,
)
return create_vision_llm(
provider_map[model.llm_provider_id],
model.name,
)
checked_models = [
f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models
]
logger.warning(
"No vision-capable model found among %d candidates: %s"
"image summarization will be disabled",
len(sorted_models),
", ".join(checked_models),
)
return None

View File

@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:
optional_kwargs["tool_choice"] = tool_choice
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,

View File

@@ -219,13 +219,26 @@ def litellm_exception_to_error_msg(
"ratelimiterror"
):
upstream_detail = upstream_detail.split(":", 1)[1].strip()
error_msg = (
f"{provider_name} rate limit: {upstream_detail}"
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
error_code = "RATE_LIMIT"
is_retryable = True
upstream_detail_lower = upstream_detail.lower()
if (
"insufficient_quota" in upstream_detail_lower
or "exceeded your current quota" in upstream_detail_lower
):
error_msg = (
f"{provider_name} quota exceeded: {upstream_detail}"
if upstream_detail
else f"{provider_name} quota exceeded: Verify billing and quota for this API key."
)
error_code = "BUDGET_EXCEEDED"
is_retryable = False
else:
error_msg = (
f"{provider_name} rate limit: {upstream_detail}"
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
error_code = "RATE_LIMIT"
is_retryable = True
elif isinstance(core_exception, ServiceUnavailableError):
provider_name = (
llm.config.model_provider

View File

@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.registry import validate_registry
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth_check import check_router_auth
from onyx.server.documents.cc_pair import router as cc_pair_router
@@ -308,6 +309,7 @@ def validate_no_vector_db_settings() -> None:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
validate_registry()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:

View File

@@ -479,7 +479,9 @@ def is_zip_file(file: UploadFile) -> bool:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
files: list[UploadFile],
file_origin: FileOrigin = FileOrigin.CONNECTOR,
unzip: bool = True,
) -> FileUploadResponse:
# Skip directories and known macOS metadata entries
@@ -502,31 +504,46 @@ def upload_files(
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
seen_zip = True
# Validate the zip by opening it (catches corrupt/non-zip files)
with zipfile.ZipFile(file.file, "r") as zf:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
if unzip:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
continue
# Store the zip as-is (unzip=False)
file.file.seek(0)
file_id = file_store.save_file(
content=file.file,
display_name=file.filename,
file_origin=file_origin,
file_type=file.content_type or "application/zip",
)
deduped_file_paths.append(file_id)
deduped_file_names.append(file.filename)
continue
# Since we can't render docx files in the UI,
@@ -613,9 +630,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
unzip: bool = True,
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files, FileOrigin.OTHER)
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
@@ -1319,7 +1337,7 @@ def get_connector_indexing_status(
# Track admin page visit for analytics
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
distinct_id=str(user.id),
event=MilestoneRecordType.VISITED_ADMIN_PAGE,
)
@@ -1533,7 +1551,7 @@ def create_connector_from_model(
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
distinct_id=str(user.id),
event=MilestoneRecordType.CREATED_CONNECTOR,
)
@@ -1611,7 +1629,7 @@ def create_connector_with_mock_credential(
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
distinct_id=str(user.id),
event=MilestoneRecordType.CREATED_CONNECTOR,
)
return response
@@ -1915,9 +1933,7 @@ def submit_connector_request(
if not connector_name:
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
# Get user identifier for telemetry
user_email = user.email
distinct_id = user_email or tenant_id
# Track connector request via PostHog telemetry (Cloud only)
from shared_configs.configs import MULTI_TENANT
@@ -1925,11 +1941,11 @@ def submit_connector_request(
if MULTI_TENANT:
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=distinct_id,
distinct_id=str(user.id),
event=MilestoneRecordType.REQUESTED_CONNECTOR,
properties={
"connector_name": connector_name,
"user_email": user_email,
"user_email": user.email,
},
)

View File

@@ -408,7 +408,7 @@ class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""
cc_pair_id: int
name: str | None
name: str
error_msg: str | None
is_deletable: bool
connector_id: int
@@ -422,7 +422,7 @@ class ConnectorStatus(BaseModel):
"""
cc_pair_id: int
name: str | None
name: str
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
@@ -453,7 +453,7 @@ class DocsCountOperator(str, Enum):
class ConnectorIndexingStatusLite(BaseModel):
cc_pair_id: int
name: str | None
name: str
source: DocumentSource
access_type: AccessType
cc_pair_status: ConnectorCredentialPairStatus
@@ -488,7 +488,7 @@ class ConnectorCredentialPairIdentifier(BaseModel):
class ConnectorCredentialPairMetadata(BaseModel):
name: str | None = None
name: str
access_type: AccessType
auto_sync_options: dict[str, Any] | None = None
groups: list[int] = Field(default_factory=list)
@@ -501,7 +501,7 @@ class CCStatusUpdateRequest(BaseModel):
class ConnectorCredentialPairDescriptor(BaseModel):
id: int
name: str | None = None
name: str
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
@@ -511,7 +511,7 @@ class CCPairSummary(BaseModel):
"""Simplified connector-credential pair information with just essential data"""
id: int
name: str | None
name: str
source: DocumentSource
access_type: AccessType

View File

@@ -15,7 +15,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.1.5",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",
@@ -1711,9 +1711,9 @@
}
},
"node_modules/@next/env": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.5.tgz",
"integrity": "sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
@@ -1727,9 +1727,9 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.5.tgz",
"integrity": "sha512-eK7Wdm3Hjy/SCL7TevlH0C9chrpeOYWx2iR7guJDaz4zEQKWcS1IMVfMb9UKBFMg1XgzcPTYPIp1Vcpukkjg6Q==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
"cpu": [
"arm64"
],
@@ -1743,9 +1743,9 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.5.tgz",
"integrity": "sha512-foQscSHD1dCuxBmGkbIr6ScAUF6pRoDZP6czajyvmXPAOFNnQUJu2Os1SGELODjKp/ULa4fulnBWoHV3XdPLfA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
"cpu": [
"x64"
],
@@ -1759,9 +1759,9 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.5.tgz",
"integrity": "sha512-qNIb42o3C02ccIeSeKjacF3HXotGsxh/FMk/rSRmCzOVMtoWH88odn2uZqF8RLsSUWHcAqTgYmPD3pZ03L9ZAA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
"cpu": [
"arm64"
],
@@ -1775,9 +1775,9 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.5.tgz",
"integrity": "sha512-U+kBxGUY1xMAzDTXmuVMfhaWUZQAwzRaHJ/I6ihtR5SbTVUEaDRiEU9YMjy1obBWpdOBuk1bcm+tsmifYSygfw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
"cpu": [
"arm64"
],
@@ -1791,9 +1791,9 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.5.tgz",
"integrity": "sha512-gq2UtoCpN7Ke/7tKaU7i/1L7eFLfhMbXjNghSv0MVGF1dmuoaPeEVDvkDuO/9LVa44h5gqpWeJ4mRRznjDv7LA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
"cpu": [
"x64"
],
@@ -1807,9 +1807,9 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.5.tgz",
"integrity": "sha512-bQWSE729PbXT6mMklWLf8dotislPle2L70E9q6iwETYEOt092GDn0c+TTNj26AjmeceSsC4ndyGsK5nKqHYXjQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
"cpu": [
"x64"
],
@@ -1823,9 +1823,9 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.5.tgz",
"integrity": "sha512-LZli0anutkIllMtTAWZlDqdfvjWX/ch8AFK5WgkNTvaqwlouiD1oHM+WW8RXMiL0+vAkAJyAGEzPPjO+hnrSNQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
"cpu": [
"arm64"
],
@@ -1839,9 +1839,9 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.5.tgz",
"integrity": "sha512-7is37HJTNQGhjPpQbkKjKEboHYQnCgpVt/4rBrrln0D9nderNxZ8ZWs8w1fAtzUx7wEyYjQ+/13myFgFj6K2Ng==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
"cpu": [
"x64"
],
@@ -4971,12 +4971,15 @@
"license": "MIT"
},
"node_modules/baseline-browser-mapping": {
"version": "2.9.17",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.17.tgz",
"integrity": "sha512-agD0MgJFUP/4nvjqzIB29zRPUuCF7Ge6mEv9s8dHrtYD7QWXRcx75rOADE/d5ah1NI+0vkDl0yorDd5U852IQQ==",
"version": "2.10.8",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
"license": "Apache-2.0",
"bin": {
"baseline-browser-mapping": "dist/cli.js"
"baseline-browser-mapping": "dist/cli.cjs"
},
"engines": {
"node": ">=6.0.0"
}
},
"node_modules/body-parser": {
@@ -8975,14 +8978,14 @@
}
},
"node_modules/next": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.5.tgz",
"integrity": "sha512-f+wE+NSbiQgh3DSAlTaw2FwY5yGdVViAtp8TotNQj4kk4Q8Bh1sC/aL9aH+Rg1YAVn18OYXsRDT7U/079jgP7w==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
"license": "MIT",
"dependencies": {
"@next/env": "16.1.5",
"@next/env": "16.1.7",
"@swc/helpers": "0.5.15",
"baseline-browser-mapping": "^2.8.3",
"baseline-browser-mapping": "^2.9.19",
"caniuse-lite": "^1.0.30001579",
"postcss": "8.4.31",
"styled-jsx": "5.1.6"
@@ -8994,14 +8997,14 @@
"node": ">=20.9.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "16.1.5",
"@next/swc-darwin-x64": "16.1.5",
"@next/swc-linux-arm64-gnu": "16.1.5",
"@next/swc-linux-arm64-musl": "16.1.5",
"@next/swc-linux-x64-gnu": "16.1.5",
"@next/swc-linux-x64-musl": "16.1.5",
"@next/swc-win32-arm64-msvc": "16.1.5",
"@next/swc-win32-x64-msvc": "16.1.5",
"@next/swc-darwin-arm64": "16.1.7",
"@next/swc-darwin-x64": "16.1.7",
"@next/swc-linux-arm64-gnu": "16.1.7",
"@next/swc-linux-arm64-musl": "16.1.7",
"@next/swc-linux-x64-gnu": "16.1.7",
"@next/swc-linux-x64-musl": "16.1.7",
"@next/swc-win32-arm64-msvc": "16.1.7",
"@next/swc-win32-x64-msvc": "16.1.7",
"sharp": "^0.34.4"
},
"peerDependencies": {

View File

@@ -16,7 +16,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.1.5",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",

View File

@@ -314,7 +314,7 @@ def create_persona(
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
distinct_id=str(user.id),
event=MilestoneRecordType.CREATED_ASSISTANT,
)

View File

@@ -81,6 +81,7 @@ from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.server.manage.llm.utils import generate_bedrock_display_name
from onyx.server.manage.llm.utils import generate_ollama_display_name
from onyx.server.manage.llm.utils import infer_vision_support
from onyx.server.manage.llm.utils import is_embedding_model
from onyx.server.manage.llm.utils import is_reasoning_model
from onyx.server.manage.llm.utils import is_valid_bedrock_model
from onyx.server.manage.llm.utils import ModelMetadata
@@ -1374,6 +1375,10 @@ def get_litellm_available_models(
try:
model_details = LitellmModelDetails.model_validate(model)
# Skip embedding models
if is_embedding_model(model_details.id):
continue
results.append(
LitellmFinalModelResponse(
provider_name=model_details.owned_by,

View File

@@ -366,3 +366,18 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
return None
return None
def is_embedding_model(model_name: str) -> bool:
"""Checks for if a model is an embedding model"""
from litellm import get_model_info
try:
# get_model_info raises on unknown models
# default to False
model_info = get_model_info(model_name)
except Exception:
return False
is_embedding_mode = model_info.get("mode") == "embedding"
return is_embedding_mode

View File

@@ -118,12 +118,6 @@ async def handle_streaming_transcription(
if result is None: # End of stream
logger.info("Streaming transcription: transcript stream ended")
break
if result.error:
logger.warning(
f"Streaming transcription: provider error: {result.error}"
)
await websocket.send_json({"type": "error", "message": result.error})
continue
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
if result.text and (result.text != last_transcript or result.is_vad_end):
last_transcript = result.text

View File

@@ -561,7 +561,7 @@ def handle_send_chat_message(
tenant_id = get_current_tenant_id()
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=tenant_id if user.is_anonymous else user.email,
distinct_id=tenant_id if user.is_anonymous else str(user.id),
event=MilestoneRecordType.RAN_QUERY,
)

View File

@@ -1,239 +0,0 @@
from __future__ import annotations
import re
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
def _adjust_message_text_for_agent_search_results(
adjusted_message_text: str,
final_documents: list[SavedSearchDoc], # noqa: ARG001
) -> str:
# Remove all [Q<integer>] patterns (sub-question citations)
return re.sub(r"\[Q\d+\]", "", adjusted_message_text)
def _replace_d_citations_with_links(
message_text: str, final_documents: list[SavedSearchDoc]
) -> str:
def replace_citation(match: re.Match[str]) -> str:
d_number = match.group(1)
try:
doc_index = int(d_number) - 1
if 0 <= doc_index < len(final_documents):
doc = final_documents[doc_index]
link = doc.link if doc.link else ""
return f"[[{d_number}]]({link})"
return match.group(0)
except (ValueError, IndexError):
return match.group(0)
return re.sub(r"\[D(\d+)\]", replace_citation, message_text)
def create_message_packets(
message_text: str,
final_documents: list[SavedSearchDoc] | None,
turn_index: int,
is_legacy_agentic: bool = False,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=AgentResponseStart(
final_documents=SearchDoc.from_saved_search_docs(final_documents or []),
),
)
)
adjusted_message_text = message_text
if is_legacy_agentic:
if final_documents is not None:
adjusted_message_text = _adjust_message_text_for_agent_search_results(
message_text, final_documents
)
adjusted_message_text = _replace_d_citations_with_links(
adjusted_message_text, final_documents
)
else:
adjusted_message_text = re.sub(r"\[Q\d+\]", "", message_text)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=AgentResponseDelta(
content=adjusted_message_text,
),
),
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SectionEnd(),
)
)
return packets
def create_citation_packets(
citation_info_list: list[CitationInfo], turn_index: int
) -> list[Packet]:
packets: list[Packet] = []
# Emit each citation as a separate CitationInfo packet
for citation_info in citation_info_list:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=citation_info,
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ReasoningDelta(
reasoning=reasoning_text,
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_image_generation_packets(
images: list[GeneratedImage], turn_index: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ImageGenerationToolStart(),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ImageGenerationFinal(images=images),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_fetch_packets(
fetch_docs: list[SavedSearchDoc],
urls: list[str],
turn_index: int,
) -> list[Packet]:
packets: list[Packet] = []
# Emit start packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlStart(),
)
)
# Emit URLs packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlUrls(urls=urls),
)
)
# Emit documents packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlDocuments(
documents=SearchDoc.from_saved_search_docs(fetch_docs)
),
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
saved_search_docs: list[SavedSearchDoc],
is_internet_search: bool,
turn_index: int,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolStart(
is_internet_search=is_internet_search,
),
)
)
# Emit queries if present
if search_queries:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolQueriesDelta(queries=search_queries),
),
)
# Emit documents if present
if saved_search_docs:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolDocumentsDelta(
documents=SearchDoc.from_saved_search_docs(saved_search_docs)
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets

View File

@@ -1,3 +1,4 @@
import hashlib
import mimetypes
from io import BytesIO
from typing import Any
@@ -83,6 +84,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
def __init__(self, tool_id: int, emitter: Emitter) -> None:
super().__init__(emitter=emitter)
self._id = tool_id
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
# the same file on every tool call iteration within the same agent session.
# Filename is included in the key so two files with identical bytes but
# different names each get their own upload slot.
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
# iterations, typically a few minutes), so stale-ID eviction is not needed.
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
@property
def id(self) -> int:
@@ -182,8 +191,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
content_hash = hashlib.sha256(chat_file.content).hexdigest()
cache_key = (file_name, content_hash)
ci_file_id = self._uploaded_file_cache.get(cache_key)
if ci_file_id is None:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
self._uploaded_file_cache[cache_key] = ci_file_id
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
@@ -299,14 +313,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Note: staged input files are intentionally not deleted here because
# _uploaded_file_cache reuses their file_ids across iterations. They are
# orphaned when the session ends, but the code interpreter cleans up
# stale files on its own TTL.
# Emit file_ids once files are processed
if generated_file_ids:

View File

@@ -74,7 +74,7 @@ def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_obj = hashlib.md5(hash_input.encode("utf-8"), usedforsecurity=False)
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case

View File

@@ -2,6 +2,7 @@ import contextvars
import threading
import uuid
from enum import Enum
from typing import Any
import requests
@@ -152,7 +153,7 @@ def mt_cloud_telemetry(
tenant_id: str,
distinct_id: str,
event: MilestoneRecordType,
properties: dict | None = None,
properties: dict[str, Any] | None = None,
) -> None:
if not MULTI_TENANT:
return
@@ -173,3 +174,18 @@ def mt_cloud_telemetry(
attribute="event_telemetry",
fallback=noop_fallback,
)(distinct_id, event, all_properties)
def mt_cloud_identify(
distinct_id: str,
properties: dict[str, Any] | None = None,
) -> None:
"""Create/update a PostHog person profile (Cloud only)."""
if not MULTI_TENANT:
return
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="identify_user",
fallback=noop_fallback,
)(distinct_id, properties)

View File

@@ -15,9 +15,6 @@ class TranscriptResult(BaseModel):
is_vad_end: bool = False
"""True if VAD detected end of speech (silence). Use for auto-send."""
error: str | None = None
"""Provider error message to forward to the client, if any."""
class StreamingTranscriberProtocol(Protocol):
"""Protocol for streaming transcription sessions."""

View File

@@ -56,17 +56,6 @@ def _http_to_ws_url(http_url: str) -> str:
return http_url
_USER_FACING_ERROR_MESSAGES: dict[str, str] = {
"input_audio_buffer_commit_empty": (
"No audio was recorded. Please check your microphone and try again."
),
"invalid_api_key": "Voice service authentication failed. Please contact support.",
"rate_limit_exceeded": "Voice service is temporarily busy. Please try again shortly.",
}
_DEFAULT_USER_ERROR = "A voice transcription error occurred. Please try again."
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using OpenAI Realtime API."""
@@ -153,17 +142,6 @@ class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
if msg_type == OpenAIRealtimeMessageType.ERROR:
error = data.get("error", {})
self._logger.error(f"OpenAI error: {error}")
error_code = error.get("code", "")
user_message = _USER_FACING_ERROR_MESSAGES.get(
error_code, _DEFAULT_USER_ERROR
)
await self._transcript_queue.put(
TranscriptResult(
text="",
is_vad_end=False,
error=user_message,
)
)
continue
# Handle VAD events

View File

@@ -65,7 +65,7 @@ attrs==25.4.0
# jsonschema
# referencing
# zeep
authlib==1.6.7
authlib==1.6.9
# via fastmcp
azure-cognitiveservices-speech==1.38.0
# via onyx
@@ -698,7 +698,7 @@ py-key-value-aio==0.4.4
# via fastmcp
pyairtable==3.0.1
# via onyx
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa
@@ -737,7 +737,7 @@ pygithub==2.5.0
# via onyx
pygments==2.19.2
# via rich
pyjwt==2.11.0
pyjwt==2.12.0
# via
# fastapi-users
# mcp
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.8.0
pypdf==6.9.1
# via
# onyx
# unstructured-client

View File

@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.0
onyx-devtools==0.7.1
# via onyx
openai==2.14.0
# via
@@ -326,7 +326,7 @@ pure-eval==0.2.3
# via stack-data
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa
@@ -353,7 +353,7 @@ pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
pyjwt==2.11.0
pyjwt==2.12.0
# via mcp
pyparsing==3.2.5
# via matplotlib

View File

@@ -195,7 +195,7 @@ propcache==0.4.1
# yarl
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa
@@ -218,7 +218,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pyjwt==2.11.0
pyjwt==2.12.0
# via mcp
python-dateutil==2.8.2
# via

View File

@@ -285,7 +285,7 @@ psutil==7.1.3
# via accelerate
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa
@@ -308,7 +308,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pyjwt==2.11.0
pyjwt==2.12.0
# via mcp
python-dateutil==2.8.2
# via

View File

@@ -0,0 +1,471 @@
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import TypedDict
from typing import TypeGuard
import aiohttp
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "http://localhost:3000"
INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
@dataclass(frozen=True)
class QuestionRecord:
question_id: str
question: str
@dataclass(frozen=True)
class AnswerRecord:
question_id: str
answer: str
document_ids: list[str]
@dataclass(frozen=True)
class FailedQuestionRecord:
question_id: str
error: str
class Citation(TypedDict, total=False):
citation_number: int
document_id: str
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Submit questions to Onyx chat with internal search forced and write "
"answers to a JSONL file."
)
)
parser.add_argument(
"--questions-file",
type=Path,
required=True,
help="Path to the input questions JSONL file.",
)
parser.add_argument(
"--output-file",
type=Path,
required=True,
help="Path to the output answers JSONL file.",
)
parser.add_argument(
"--api-key",
type=str,
required=True,
help="API key used to authenticate against Onyx.",
)
parser.add_argument(
"--api-base",
type=str,
default=DEFAULT_API_BASE,
help=(
"Frontend base URL for Onyx. If `/api` is omitted, it will be added "
f"automatically. Default: {DEFAULT_API_BASE}"
),
)
parser.add_argument(
"--parallelism",
type=int,
default=1,
help="Number of questions to process in parallel. Default: 1.",
)
parser.add_argument(
"--max-questions",
type=int,
default=None,
help="Optional cap on how many questions to process. Defaults to all.",
)
return parser.parse_args()
def normalize_api_base(api_base: str) -> str:
normalized = api_base.rstrip("/")
if normalized.endswith("/api"):
return normalized
return f"{normalized}/api"
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
questions: list[QuestionRecord] = []
with questions_file.open("r", encoding="utf-8") as file:
for line_number, line in enumerate(file, start=1):
stripped_line = line.strip()
if not stripped_line:
continue
try:
payload = json.loads(stripped_line)
except json.JSONDecodeError as exc:
raise ValueError(
f"Invalid JSON on line {line_number} of {questions_file}"
) from exc
question_id = payload.get("question_id")
question = payload.get("question")
if not isinstance(question_id, str) or not question_id:
raise ValueError(
f"Line {line_number} is missing a non-empty `question_id`."
)
if not isinstance(question, str) or not question:
raise ValueError(
f"Line {line_number} is missing a non-empty `question`."
)
questions.append(QuestionRecord(question_id=question_id, question=question))
return questions
async def read_json_response(
response: aiohttp.ClientResponse,
) -> dict[str, Any] | list[dict[str, Any]]:
response_text = await response.text()
if response.status >= 400:
raise RuntimeError(
f"Request to {response.url} failed with {response.status}: {response_text}"
)
try:
payload = json.loads(response_text)
except json.JSONDecodeError as exc:
raise RuntimeError(
f"Request to {response.url} returned non-JSON content: {response_text}"
) from exc
if not isinstance(payload, (dict, list)):
raise RuntimeError(
f"Unexpected response payload type from {response.url}: {type(payload)}"
)
return payload
async def request_json_with_retries(
session: aiohttp.ClientSession,
method: str,
url: str,
headers: dict[str, str],
json_payload: dict[str, Any] | None = None,
) -> dict[str, Any] | list[dict[str, Any]]:
backoff_seconds = 1.0
for attempt in range(1, MAX_REQUEST_ATTEMPTS + 1):
try:
async with session.request(
method=method,
url=url,
headers=headers,
json=json_payload,
) as response:
if (
response.status in RETRIABLE_STATUS_CODES
and attempt < MAX_REQUEST_ATTEMPTS
):
response_text = await response.text()
logger.warning(
"Retryable response from %s on attempt %s/%s: %s %s",
url,
attempt,
MAX_REQUEST_ATTEMPTS,
response.status,
response_text,
)
await asyncio.sleep(backoff_seconds)
backoff_seconds *= 2
continue
return await read_json_response(response)
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
if attempt == MAX_REQUEST_ATTEMPTS:
raise RuntimeError(
f"Request to {url} failed after {MAX_REQUEST_ATTEMPTS} attempts."
) from exc
logger.warning(
"Request to %s failed on attempt %s/%s: %s",
url,
attempt,
MAX_REQUEST_ATTEMPTS,
exc,
)
await asyncio.sleep(backoff_seconds)
backoff_seconds *= 2
raise RuntimeError(f"Request to {url} failed unexpectedly.")
def extract_document_ids(citation_info: object) -> list[str]:
if not isinstance(citation_info, list):
return []
sorted_citations = sorted(
(citation for citation in citation_info if _is_valid_citation(citation)),
key=_citation_sort_key,
)
document_ids: list[str] = []
seen_document_ids: set[str] = set()
for citation in sorted_citations:
document_id = citation["document_id"]
if document_id not in seen_document_ids:
seen_document_ids.add(document_id)
document_ids.append(document_id)
return document_ids
def _is_valid_citation(citation: object) -> TypeGuard[Citation]:
return (
isinstance(citation, dict)
and isinstance(citation.get("document_id"), str)
and bool(citation["document_id"])
)
def _citation_sort_key(citation: Citation) -> int:
citation_number = citation.get("citation_number")
if isinstance(citation_number, int):
return citation_number
return sys.maxsize
async def fetch_internal_search_tool_id(
session: aiohttp.ClientSession,
api_base: str,
headers: dict[str, str],
) -> int:
payload = await request_json_with_retries(
session=session,
method="GET",
url=f"{api_base}/tool",
headers=headers,
)
if not isinstance(payload, list):
raise RuntimeError("Expected `/tool` to return a list.")
for tool in payload:
if not isinstance(tool, dict):
continue
if tool.get("in_code_tool_id") == INTERNAL_SEARCH_IN_CODE_TOOL_ID:
tool_id = tool.get("id")
if isinstance(tool_id, int):
return tool_id
for tool in payload:
if not isinstance(tool, dict):
continue
if tool.get("name") == INTERNAL_SEARCH_TOOL_NAME:
tool_id = tool.get("id")
if isinstance(tool_id, int):
return tool_id
raise RuntimeError(
"Could not find the internal search tool in `/tool`. "
"Make sure SearchTool is available for this environment."
)
async def submit_question(
session: aiohttp.ClientSession,
api_base: str,
headers: dict[str, str],
internal_search_tool_id: int,
question_record: QuestionRecord,
) -> AnswerRecord:
payload = {
"message": question_record.question,
"chat_session_info": {"persona_id": 0},
"parent_message_id": None,
"file_descriptors": [],
"allowed_tool_ids": [internal_search_tool_id],
"forced_tool_id": internal_search_tool_id,
"stream": False,
}
response_payload = await request_json_with_retries(
session=session,
method="POST",
url=f"{api_base}/chat/send-chat-message",
headers=headers,
json_payload=payload,
)
if not isinstance(response_payload, dict):
raise RuntimeError(
"Expected `/chat/send-chat-message` to return an object when `stream=false`."
)
answer = response_payload.get("answer_citationless")
if not isinstance(answer, str):
answer = response_payload.get("answer")
if not isinstance(answer, str):
raise RuntimeError(
f"Response for question {question_record.question_id} is missing `answer`."
)
return AnswerRecord(
question_id=question_record.question_id,
answer=answer,
document_ids=extract_document_ids(response_payload.get("citation_info")),
)
async def generate_answers(
questions: list[QuestionRecord],
output_file: Path,
api_base: str,
api_key: str,
parallelism: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
timeout = aiohttp.ClientTimeout(
total=None,
connect=30,
sock_connect=30,
sock_read=600,
)
connector = aiohttp.TCPConnector(limit=parallelism)
output_file.parent.mkdir(parents=True, exist_ok=True)
with output_file.open("a", encoding="utf-8") as file:
async with aiohttp.ClientSession(
timeout=timeout, connector=connector
) as session:
internal_search_tool_id = await fetch_internal_search_tool_id(
session=session,
api_base=api_base,
headers=headers,
)
logger.info("Using internal search tool id %s", internal_search_tool_id)
semaphore = asyncio.Semaphore(parallelism)
progress_lock = asyncio.Lock()
write_lock = asyncio.Lock()
completed = 0
successful = 0
failed_questions: list[FailedQuestionRecord] = []
total = len(questions)
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
)
except Exception as exc:
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
return
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
successful += 1
logger.info("Processed %s/%s questions", completed, total)
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
if failed_questions:
logger.warning(
"Completed with %s failed questions and %s successful questions.",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
"Failed question %s: %s",
failed_question.question_id,
failed_question.error,
)
def main() -> None:
args = parse_args()
questions = load_questions(args.questions_file)
api_base = normalize_api_base(args.api_base)
if args.max_questions is not None:
if args.max_questions < 1:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
generate_answers(
questions=questions,
output_file=args.output_file,
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
)
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,291 @@
"""
Script to upload files from a directory as individual file connectors in Onyx.
Each file gets its own connector named after the file.
Usage:
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --api-base http://onyxserver:3000
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --file-glob '*.zip'
Requires:
pip install requests
"""
import argparse
import fnmatch
import os
import sys
import threading
import time
import requests
REQUEST_TIMEOUT = 900 # 15 minutes
def _elapsed_printer(label: str, stop_event: threading.Event) -> None:
"""Print a live elapsed-time counter until stop_event is set."""
start = time.monotonic()
while not stop_event.wait(timeout=1):
elapsed = int(time.monotonic() - start)
m, s = divmod(elapsed, 60)
print(f"\r {label} ... {m:02d}:{s:02d}", end="", flush=True)
elapsed = int(time.monotonic() - start)
m, s = divmod(elapsed, 60)
print(f"\r {label} ... {m:02d}:{s:02d} done")
def _timed_request(label: str, fn: object) -> requests.Response:
"""Run a request function while displaying a live elapsed timer."""
stop = threading.Event()
t = threading.Thread(target=_elapsed_printer, args=(label, stop), daemon=True)
t.start()
try:
resp = fn() # type: ignore[operator]
finally:
stop.set()
t.join()
return resp
def upload_file(
session: requests.Session, base_url: str, file_path: str
) -> dict | None:
"""Upload a single file and return the response with file_paths and file_names."""
with open(file_path, "rb") as f:
resp = _timed_request(
"Uploading",
lambda: session.post(
f"{base_url}/api/manage/admin/connector/file/upload",
files={"files": (os.path.basename(file_path), f)},
timeout=REQUEST_TIMEOUT,
),
)
if not resp.ok:
print(f" ERROR uploading: {resp.text}")
return None
return resp.json()
def create_connector(
session: requests.Session,
base_url: str,
name: str,
file_paths: list[str],
file_names: list[str],
zip_metadata_file_id: str | None,
) -> int | None:
"""Create a file connector and return its ID."""
resp = _timed_request(
"Creating connector",
lambda: session.post(
f"{base_url}/api/manage/admin/connector",
json={
"name": name,
"source": "file",
"input_type": "load_state",
"connector_specific_config": {
"file_locations": file_paths,
"file_names": file_names,
"zip_metadata_file_id": zip_metadata_file_id,
},
"refresh_freq": None,
"prune_freq": None,
"indexing_start": None,
"access_type": "public",
"groups": [],
},
timeout=REQUEST_TIMEOUT,
),
)
if not resp.ok:
print(f" ERROR creating connector: {resp.text}")
return None
return resp.json()["id"]
def create_credential(
session: requests.Session, base_url: str, name: str
) -> int | None:
"""Create a dummy credential for the file connector."""
resp = session.post(
f"{base_url}/api/manage/credential",
json={
"credential_json": {},
"admin_public": True,
"source": "file",
"curator_public": True,
"groups": [],
"name": name,
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR creating credential: {resp.text}")
return None
return resp.json()["id"]
def link_credential(
session: requests.Session,
base_url: str,
connector_id: int,
credential_id: int,
name: str,
) -> bool:
"""Link the connector to the credential (create CC pair)."""
resp = session.put(
f"{base_url}/api/manage/connector/{connector_id}/credential/{credential_id}",
json={
"name": name,
"access_type": "public",
"groups": [],
"auto_sync_options": None,
"processing_mode": "REGULAR",
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR linking credential: {resp.text}")
return False
return True
def run_connector(
session: requests.Session,
base_url: str,
connector_id: int,
credential_id: int,
) -> bool:
"""Trigger the connector to start indexing."""
resp = session.post(
f"{base_url}/api/manage/admin/connector/run-once",
json={
"connector_id": connector_id,
"credentialIds": [credential_id],
"from_beginning": False,
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR running connector: {resp.text}")
return False
return True
def process_file(session: requests.Session, base_url: str, file_path: str) -> bool:
"""Process a single file through the full connector creation flow."""
file_name = os.path.basename(file_path)
connector_name = file_name
print(f"Processing: {file_name}")
# Step 1: Upload
upload_resp = upload_file(session, base_url, file_path)
if not upload_resp:
return False
# Step 2: Create connector
connector_id = create_connector(
session,
base_url,
name=f"FileConnector-{connector_name}",
file_paths=upload_resp["file_paths"],
file_names=upload_resp["file_names"],
zip_metadata_file_id=upload_resp.get("zip_metadata_file_id"),
)
if connector_id is None:
return False
# Step 3: Create credential
credential_id = create_credential(session, base_url, name=connector_name)
if credential_id is None:
return False
# Step 4: Link connector to credential
if not link_credential(
session, base_url, connector_id, credential_id, connector_name
):
return False
# Step 5: Trigger indexing
if not run_connector(session, base_url, connector_id, credential_id):
return False
print(f" OK (connector_id={connector_id})")
return True
def get_authenticated_session(api_key: str) -> requests.Session:
"""Create a session authenticated with an API key."""
session = requests.Session()
session.headers.update({"Authorization": f"Bearer {api_key}"})
return session
def main() -> None:
parser = argparse.ArgumentParser(
description="Upload files as individual Onyx file connectors."
)
parser.add_argument(
"--data-dir",
required=True,
help="Directory containing files to upload.",
)
parser.add_argument(
"--api-base",
default="http://localhost:3000",
help="Base URL for the Onyx API (default: http://localhost:3000).",
)
parser.add_argument(
"--api-key",
required=True,
help="API key for authentication.",
)
parser.add_argument(
"--file-glob",
default=None,
help="Glob pattern to filter files (e.g. '*.json', '*.zip').",
)
args = parser.parse_args()
data_dir = args.data_dir
base_url = args.api_base.rstrip("/")
api_key = args.api_key
file_glob = args.file_glob
if not os.path.isdir(data_dir):
print(f"Error: {data_dir} is not a directory")
sys.exit(1)
script_path = os.path.realpath(__file__)
files = sorted(
os.path.join(data_dir, f)
for f in os.listdir(data_dir)
if os.path.isfile(os.path.join(data_dir, f))
and os.path.realpath(os.path.join(data_dir, f)) != script_path
and (file_glob is None or fnmatch.fnmatch(f, file_glob))
)
if not files:
print(f"No files found in {data_dir}")
sys.exit(1)
print(f"Found {len(files)} file(s) in {data_dir}\n")
session = get_authenticated_session(api_key)
success = 0
failed = 0
for file_path in files:
if process_file(session, base_url, file_path):
success += 1
else:
failed += 1
# Small delay to avoid overwhelming the server
time.sleep(0.5)
print(f"\nDone: {success} succeeded, {failed} failed out of {len(files)} files.")
if __name__ == "__main__":
main()

View File

@@ -45,6 +45,21 @@ npx playwright test <TEST_NAME>
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
their own `conftest.py` for directory-scoped fixtures.
## Running Tests Repeatedly (`pytest-repeat`)
Use `pytest-repeat` to catch flaky tests by running them multiple times:
```bash
# Run a specific test 50 times
pytest --count=50 backend/tests/unit/path/to/test.py::test_name
# Stop on first failure with -x
pytest --count=50 -x backend/tests/unit/path/to/test.py::test_name
# Repeat an entire test file
pytest --count=10 backend/tests/unit/path/to/test_file.py
```
## Best Practices
### Use `enable_ee` fixture instead of inlining

View File

@@ -0,0 +1,274 @@
"""
External dependency unit tests for user file delete queue protections.
Verifies that the three mechanisms added to check_for_user_file_delete work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_DELETE_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in DELETING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_DELETE_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that delete_user_file_impl clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_for_user_file_delete,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file_delete,
)
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_deleting_user_file(db_session: Session, user_id: object) -> UserFile:
"""Insert a UserFile in DELETING status and return it."""
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.DELETING,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestDeleteQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "del_bp_user")
_create_deleting_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=USER_FILE_DELETE_MAX_QUEUE_DEPTH + 1),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestDeletePerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "del_guard_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "del_guard_set_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert (
0 < ttl <= CELERY_USER_FILE_DELETE_TASK_EXPIRES
), f"Guard key TTL {ttl}s is outside the expected range (0, {CELERY_USER_FILE_DELETE_TASK_EXPIRES}]"
finally:
redis_client.delete(guard_key)
class TestDeleteTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "del_expires_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.DELETE_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_DELETE
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_DELETE_TASK_EXPIRES
), "Task must be submitted with the correct expires value to prevent stale task accumulation"
finally:
redis_client.delete(guard_key)
class TestDeleteWorkerClearsGuardKey:
"""process_single_user_file_delete removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so delete_user_file_impl returns
early but crucially, after the guard key deletion.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file delete lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_delete_lock_key(user_file_id)
delete_lock = redis_client.lock(lock_key, timeout=10)
acquired = delete_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the delete lock for this test"
try:
process_single_user_file_delete.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if delete_lock.owned():
delete_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -297,6 +297,10 @@ def index_batch_params(
class TestDocumentIndexOld:
"""Tests the old DocumentIndex interface."""
# TODO(ENG-3864)(andrei): Re-enable this test.
@pytest.mark.xfail(
reason="Flaky test: Retrieved chunks vary non-deterministically before and after changing user projects and personas. Likely a timing issue with the index being updated."
)
def test_update_single_can_clear_user_projects_and_personas(
self,
document_indices: list[DocumentIndex],

View File

@@ -22,18 +22,26 @@ from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline
from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration
from onyx.document_index.opensearch.opensearch_document_index import (
generate_opensearch_filtered_access_control_list,
)
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
get_min_max_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
get_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
get_zscore_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import MIN_MAX_NORMALIZATION_PIPELINE_NAME
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -49,6 +57,63 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
monkeypatch.setattr("onyx.document_index.opensearch.schema.MULTI_TENANT", state)
def _patch_hybrid_search_subquery_configuration(
monkeypatch: pytest.MonkeyPatch, configuration: HybridSearchSubqueryConfiguration
) -> None:
"""
Patches HYBRID_SEARCH_SUBQUERY_CONFIGURATION wherever necessary for this
test file.
Args:
monkeypatch: The test instance's monkeypatch instance, used for
patching.
configuration: The intended state of
HYBRID_SEARCH_SUBQUERY_CONFIGURATION.
"""
monkeypatch.setattr(
"onyx.document_index.opensearch.constants.HYBRID_SEARCH_SUBQUERY_CONFIGURATION",
configuration,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.HYBRID_SEARCH_SUBQUERY_CONFIGURATION",
configuration,
)
def _patch_hybrid_search_normalization_pipeline(
monkeypatch: pytest.MonkeyPatch, pipeline: HybridSearchNormalizationPipeline
) -> None:
"""
Patches HYBRID_SEARCH_NORMALIZATION_PIPELINE wherever necessary for this
test file.
"""
monkeypatch.setattr(
"onyx.document_index.opensearch.constants.HYBRID_SEARCH_NORMALIZATION_PIPELINE",
pipeline,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.HYBRID_SEARCH_NORMALIZATION_PIPELINE",
pipeline,
)
def _patch_opensearch_match_highlights_disabled(
monkeypatch: pytest.MonkeyPatch, disabled: bool
) -> None:
"""
Patches OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED wherever necessary for this
test file.
"""
monkeypatch.setattr(
"onyx.configs.app_configs.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
def _create_test_document_chunk(
document_id: str,
content: str,
@@ -144,14 +209,27 @@ def test_client(
@pytest.fixture(scope="function")
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
"""Creates a search pipeline for testing with automatic cleanup."""
min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = (
get_min_max_normalization_pipeline_name_and_config()
)
zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = (
get_zscore_normalization_pipeline_name_and_config()
)
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=min_max_normalization_pipeline_name,
pipeline_body=min_max_normalization_pipeline_config,
)
test_client.create_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
pipeline_body=zscore_normalization_pipeline_config,
)
yield # Test runs here.
try:
test_client.delete_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_id=min_max_normalization_pipeline_name,
)
test_client.delete_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
)
except Exception:
pass
@@ -166,7 +244,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test.
# Should not raise.
@@ -182,7 +260,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -211,7 +289,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
@@ -225,7 +303,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test and postcondition.
# Should return False before creation.
@@ -245,7 +323,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -280,7 +358,7 @@ class TestOpenSearchClient:
},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
@@ -323,7 +401,7 @@ class TestOpenSearchClient:
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
@@ -358,7 +436,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Create once - should succeed.
test_client.create_index(mappings=mappings, settings=settings)
@@ -377,18 +455,19 @@ class TestOpenSearchClient:
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests creating and deleting a search pipeline."""
# Precondition.
pipeline_name, pipeline_config = get_normalization_pipeline_name_and_config()
# Under test and postcondition.
# Should not raise.
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=pipeline_name,
pipeline_body=pipeline_config,
)
# Under test and postcondition.
# Should not raise.
test_client.delete_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
test_client.delete_search_pipeline(pipeline_id=pipeline_name)
def test_index_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
@@ -400,7 +479,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -428,7 +507,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
docs = [
@@ -459,7 +538,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -487,7 +566,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
original_doc = _create_test_document_chunk(
@@ -522,7 +601,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -541,7 +620,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -577,7 +656,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -598,7 +677,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index multiple documents.
@@ -674,7 +753,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Create a document to update.
@@ -723,7 +802,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -734,22 +813,22 @@ class TestOpenSearchClient:
properties_to_update={"hidden": True},
)
def test_hybrid_search_with_pipeline(
def test_hybrid_search_configurations_and_pipelines(
self,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests hybrid search with a normalization pipeline."""
"""Tests all hybrid search configurations and pipelines."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents.
docs = {
"doc-1": _create_test_document_chunk(
@@ -780,40 +859,62 @@ class TestOpenSearchClient:
# Refresh index to make documents searchable.
test_client.refresh_index()
# Search query.
query_text = "Python programming"
query_vector = _generate_test_vector(0.12)
search_body = DocumentQuery.get_hybrid_search_query(
query_text=query_text,
query_vector=query_vector,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
for configuration in HybridSearchSubqueryConfiguration:
_patch_hybrid_search_subquery_configuration(monkeypatch, configuration)
for pipeline in HybridSearchNormalizationPipeline:
_patch_hybrid_search_normalization_pipeline(monkeypatch, pipeline)
pipeline_name, pipeline_config = (
get_normalization_pipeline_name_and_config()
)
test_client.create_search_pipeline(
pipeline_id=pipeline_name,
pipeline_body=pipeline_config,
)
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
# Search query.
query_text = "Python programming"
query_vector = _generate_test_vector(0.12)
search_body = DocumentQuery.get_hybrid_search_query(
query_text=query_text,
query_vector=query_vector,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(
access_control_list=None, tenant_id=None
),
include_hidden=False,
)
# Postcondition.
assert len(results) == len(docs)
# Assert that all the chunks above are present.
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
# Make sure the chunk contents are preserved.
for i, chunk in enumerate(results):
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight only for the first
# result. The other results are so bad they're not expected to have
# match highlights.
if i == 0:
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=pipeline_name
)
# Postcondition.
assert len(results) == len(docs)
# Assert that all the chunks above are present.
assert all(
chunk.document_chunk.document_id in docs.keys() for chunk in results
)
# Make sure the chunk contents are preserved.
for i, chunk in enumerate(results):
expected = docs[chunk.document_chunk.document_id]
assert chunk.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight only for the first
# result. The other results are so bad they're not expected to have
# match highlights.
if i == 0:
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
def test_search_empty_index(
self,
@@ -828,7 +929,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Note no documents were indexed.
@@ -845,11 +946,10 @@ class TestOpenSearchClient:
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
assert len(results) == 0
@@ -865,12 +965,13 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -948,11 +1049,10 @@ class TestOpenSearchClient:
),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
# Should only get the public, non-hidden document, and the private
@@ -962,7 +1062,12 @@ class TestOpenSearchClient:
# ordered; we're just assuming which doc will be the first result here.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == docs["public-doc"]
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
@@ -970,7 +1075,12 @@ class TestOpenSearchClient:
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == docs["private-doc-user-a"]
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
@@ -986,11 +1096,12 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with varying relevance to the query.
@@ -1067,11 +1178,10 @@ class TestOpenSearchClient:
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
# Should only get public, non-hidden documents (3 out of 5).
@@ -1118,7 +1228,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Although very unlikely in practice, let's use the same doc ID just to
@@ -1211,7 +1321,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Don't index any documents.
@@ -1238,7 +1348,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents.
@@ -1306,7 +1416,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -1383,7 +1493,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index docs with various ages.
@@ -1441,15 +1551,16 @@ class TestOpenSearchClient:
),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
last_week_results = test_client.search(
body=last_week_search_body,
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=pipeline_name,
)
last_six_months_results = test_client.search(
body=last_six_months_search_body,
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=pipeline_name,
)
# Postcondition.
@@ -1474,7 +1585,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents, one hidden one not.
@@ -1523,4 +1634,281 @@ class TestOpenSearchClient:
for result in results:
# Note each result must be from doc 1, which is not hidden.
expected_result = doc1_chunks[result.document_chunk.chunk_index]
assert result.document_chunk == expected_result
assert result.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected_result, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
def test_keyword_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests keyword search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
# Tests that we don't return documents that don't match keywords at
# all, even if they match filters.
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
document_id="private-but-not-relevant-doc-user-a",
chunk_index=0,
content="This text should not match the query at all",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
# Should not match private-but-not-relevant-doc-user-a.
query_text = "document content"
search_body = DocumentQuery.get_keyword_search_query(
query_text=query_text,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# This should be the highest-ranked result, as a higher percentage of
# the content matches the query.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
# Make sure there is some kind of match highlight.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
assert results[1].score < results[0].score
def test_semantic_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests semantic search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
# Make this identical to the query vector to test that this
# result is returned first.
content_vector=_generate_test_vector(0.6),
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
# Make this different from the query vector to test that this
# result is returned second.
content_vector=_generate_test_vector(0.5),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
query_vector = _generate_test_vector(0.6)
search_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_vector,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# We explicitly expect this to be the highest-ranked result.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[0].score == 1.0
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert 0.0 < results[1].score < 1.0

View File

@@ -31,7 +31,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
)
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.configs.constants import SOURCE_TYPE
from onyx.context.search.models import IndexFilters
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
@@ -44,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
@@ -70,6 +70,7 @@ from onyx.document_index.vespa_constants import SOURCE_LINKS
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_PROJECT
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.full_setup import ensure_full_deployment_setup
@@ -78,24 +79,22 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient,
document_id: str,
tenant_state: TenantState,
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
query_body = DocumentQuery.get_from_document_id_query(
document_id=document_id,
tenant_state=TenantState(tenant_id=current_tenant_id, multitenant=False),
index_filters=filters,
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
)
search_hits = opensearch_client.search(
body=query_body,
search_pipeline_id=None,
)
return [search_hit.document_chunk for search_hit in search_hits]
results: list[DocumentChunk] = []
for i in range(CHUNK_COUNT):
document_chunk_id: str = get_opensearch_doc_chunk_id(
tenant_state=tenant_state,
document_id=document_id,
chunk_index=i,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
)
result = opensearch_client.get_document(document_chunk_id)
results.append(result)
return results
def _delete_document_chunks_from_opensearch(
@@ -452,10 +451,13 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -477,7 +479,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -522,6 +524,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -536,7 +541,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -559,7 +564,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Run the remainder of the migration.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -583,7 +588,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -630,6 +635,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -646,7 +654,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -691,7 +699,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -728,7 +736,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -752,7 +760,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -840,24 +848,25 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
chunk["content"] = (
f"Different content {chunk[CHUNK_ID]} for {test_documents[0].id}"
)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
chunks_for_document_in_opensearch, _ = (
transform_vespa_chunks_to_opensearch_chunks(
document_in_opensearch,
TenantState(tenant_id=get_current_tenant_id(), multitenant=False),
tenant_state,
{},
)
)
opensearch_client.bulk_index_documents(
documents=chunks_for_document_in_opensearch,
tenant_state=TenantState(
tenant_id=get_current_tenant_id(), multitenant=False
),
tenant_state=tenant_state,
update_if_exists=True,
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -878,7 +887,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -922,11 +931,14 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
# First run.
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -947,7 +959,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -960,7 +972,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Second run.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -982,7 +994,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)

View File

@@ -1219,15 +1219,16 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed via streaming, staged file cleaned up
# Verify: file uploaded and code executed via streaming.
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
# Staged input files are intentionally NOT deleted — PythonTool caches their
# file IDs across agent-loop iterations to avoid re-uploading on every call.
# The code interpreter cleans them up via its own TTL.
assert len(mock_ci_server.get_requests(method="DELETE")) == 0
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/stream"

View File

@@ -0,0 +1,120 @@
import pytest
from onyx.auth.users import _is_same_origin
class TestExactMatch:
"""Origins that are textually identical should always match."""
@pytest.mark.parametrize(
"origin",
[
"http://localhost:3000",
"https://app.example.com",
"https://app.example.com:8443",
"http://127.0.0.1:8080",
],
)
def test_identical_origins(self, origin: str) -> None:
assert _is_same_origin(origin, origin)
class TestLoopbackPortRelaxation:
"""On loopback addresses, port differences should be ignored."""
@pytest.mark.parametrize(
"actual,expected",
[
("http://localhost:3001", "http://localhost:3000"),
("http://localhost:8080", "http://localhost:3000"),
("http://localhost", "http://localhost:3000"),
("http://127.0.0.1:3001", "http://127.0.0.1:3000"),
("http://[::1]:3001", "http://[::1]:3000"),
],
)
def test_loopback_different_ports_accepted(
self, actual: str, expected: str
) -> None:
assert _is_same_origin(actual, expected)
@pytest.mark.parametrize(
"actual,expected",
[
("https://localhost:3001", "http://localhost:3000"),
("http://localhost:3001", "https://localhost:3000"),
],
)
def test_loopback_different_scheme_rejected(
self, actual: str, expected: str
) -> None:
assert not _is_same_origin(actual, expected)
def test_loopback_hostname_mismatch_rejected(self) -> None:
assert not _is_same_origin("http://localhost:3001", "http://127.0.0.1:3000")
class TestNonLoopbackStrictPort:
"""Non-loopback origins must match scheme, hostname, AND port."""
def test_different_port_rejected(self) -> None:
assert not _is_same_origin(
"https://app.example.com:8443", "https://app.example.com"
)
def test_different_hostname_rejected(self) -> None:
assert not _is_same_origin("https://evil.com", "https://app.example.com")
def test_different_scheme_rejected(self) -> None:
assert not _is_same_origin("http://app.example.com", "https://app.example.com")
def test_same_port_explicit(self) -> None:
assert _is_same_origin(
"https://app.example.com:443", "https://app.example.com:443"
)
class TestDefaultPortNormalization:
"""Port should be normalized so that omitted default port == explicit default port."""
def test_http_implicit_vs_explicit_80(self) -> None:
assert _is_same_origin("http://example.com", "http://example.com:80")
def test_http_explicit_80_vs_implicit(self) -> None:
assert _is_same_origin("http://example.com:80", "http://example.com")
def test_https_implicit_vs_explicit_443(self) -> None:
assert _is_same_origin("https://example.com", "https://example.com:443")
def test_https_explicit_443_vs_implicit(self) -> None:
assert _is_same_origin("https://example.com:443", "https://example.com")
def test_http_non_default_port_vs_implicit_rejected(self) -> None:
assert not _is_same_origin("http://example.com:8080", "http://example.com")
class TestTrailingSlash:
"""Trailing slashes should not affect comparison."""
def test_trailing_slash_on_actual(self) -> None:
assert _is_same_origin("https://app.example.com/", "https://app.example.com")
def test_trailing_slash_on_expected(self) -> None:
assert _is_same_origin("https://app.example.com", "https://app.example.com/")
def test_trailing_slash_on_both(self) -> None:
assert _is_same_origin("https://app.example.com/", "https://app.example.com/")
class TestCSWSHScenarios:
"""Realistic attack scenarios that must be rejected."""
def test_remote_attacker_rejected(self) -> None:
assert not _is_same_origin("https://evil.com", "http://localhost:3000")
def test_remote_attacker_same_port_rejected(self) -> None:
assert not _is_same_origin("http://evil.com:3000", "http://localhost:3000")
def test_remote_attacker_matching_hostname_different_port(self) -> None:
assert not _is_same_origin(
"https://app.example.com:9999", "https://app.example.com"
)

View File

@@ -0,0 +1,194 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.background.celery.tasks.hierarchyfetching.tasks import (
_connector_supports_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import (
check_for_hierarchy_fetching,
)
from onyx.connectors.factory import ConnectorMissingException
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import HierarchyConnector
from onyx.connectors.interfaces import HierarchyOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
TASKS_MODULE = "onyx.background.celery.tasks.hierarchyfetching.tasks"
class _NonHierarchyConnector(BaseConnector):
def load_credentials(self, credentials: dict) -> dict | None: # noqa: ARG002
return None
class _HierarchyCapableConnector(HierarchyConnector):
def load_credentials(self, credentials: dict) -> dict | None: # noqa: ARG002
return None
def load_hierarchy(
self,
start: SecondsSinceUnixEpoch, # noqa: ARG002
end: SecondsSinceUnixEpoch, # noqa: ARG002
) -> HierarchyOutput:
return
yield
def _build_cc_pair_mock() -> MagicMock:
cc_pair = MagicMock()
cc_pair.connector.source = "mock-source"
cc_pair.connector.input_type = "mock-input-type"
return cc_pair
def _build_redis_mock_with_lock() -> tuple[MagicMock, MagicMock]:
redis_client = MagicMock()
lock = MagicMock()
lock.acquire.return_value = True
lock.owned.return_value = True
redis_client.lock.return_value = lock
return redis_client, lock
@patch(f"{TASKS_MODULE}.identify_connector_class")
def test_connector_supports_hierarchy_fetching_false_for_non_hierarchy_connector(
mock_identify_connector_class: MagicMock,
) -> None:
mock_identify_connector_class.return_value = _NonHierarchyConnector
assert _connector_supports_hierarchy_fetching(_build_cc_pair_mock()) is False
mock_identify_connector_class.assert_called_once_with("mock-source")
@patch(f"{TASKS_MODULE}.task_logger.warning")
@patch(f"{TASKS_MODULE}.identify_connector_class")
def test_connector_supports_hierarchy_fetching_false_when_class_missing(
mock_identify_connector_class: MagicMock,
mock_warning: MagicMock,
) -> None:
mock_identify_connector_class.side_effect = ConnectorMissingException("missing")
assert _connector_supports_hierarchy_fetching(_build_cc_pair_mock()) is False
mock_warning.assert_called_once()
@patch(f"{TASKS_MODULE}.identify_connector_class")
def test_connector_supports_hierarchy_fetching_true_for_supported_connector(
mock_identify_connector_class: MagicMock,
) -> None:
mock_identify_connector_class.return_value = _HierarchyCapableConnector
assert _connector_supports_hierarchy_fetching(_build_cc_pair_mock()) is True
mock_identify_connector_class.assert_called_once_with("mock-source")
@patch(f"{TASKS_MODULE}._try_creating_hierarchy_fetching_task")
@patch(f"{TASKS_MODULE}._is_hierarchy_fetching_due")
@patch(f"{TASKS_MODULE}.get_connector_credential_pair_from_id")
@patch(f"{TASKS_MODULE}.fetch_indexable_standard_connector_credential_pair_ids")
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
@patch(f"{TASKS_MODULE}.get_redis_client")
@patch(f"{TASKS_MODULE}._connector_supports_hierarchy_fetching")
def test_check_for_hierarchy_fetching_skips_unsupported_connectors(
mock_supports_hierarchy_fetching: MagicMock,
mock_get_redis_client: MagicMock,
mock_get_session: MagicMock,
mock_fetch_cc_pair_ids: MagicMock,
mock_get_cc_pair: MagicMock,
mock_is_due: MagicMock,
mock_try_create_task: MagicMock,
) -> None:
redis_client, lock = _build_redis_mock_with_lock()
mock_get_redis_client.return_value = redis_client
mock_get_session.return_value.__enter__.return_value = MagicMock()
mock_fetch_cc_pair_ids.return_value = [123]
mock_get_cc_pair.return_value = _build_cc_pair_mock()
mock_supports_hierarchy_fetching.return_value = False
mock_is_due.return_value = True
task_app = MagicMock()
with patch.object(check_for_hierarchy_fetching, "app", task_app):
result = check_for_hierarchy_fetching.run(tenant_id="test-tenant")
assert result == 0
mock_is_due.assert_not_called()
mock_try_create_task.assert_not_called()
lock.release.assert_called_once()
@patch(f"{TASKS_MODULE}._try_creating_hierarchy_fetching_task")
@patch(f"{TASKS_MODULE}._is_hierarchy_fetching_due")
@patch(f"{TASKS_MODULE}.get_connector_credential_pair_from_id")
@patch(f"{TASKS_MODULE}.fetch_indexable_standard_connector_credential_pair_ids")
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
@patch(f"{TASKS_MODULE}.get_redis_client")
@patch(f"{TASKS_MODULE}._connector_supports_hierarchy_fetching")
def test_check_for_hierarchy_fetching_creates_task_for_supported_due_connector(
mock_supports_hierarchy_fetching: MagicMock,
mock_get_redis_client: MagicMock,
mock_get_session: MagicMock,
mock_fetch_cc_pair_ids: MagicMock,
mock_get_cc_pair: MagicMock,
mock_is_due: MagicMock,
mock_try_create_task: MagicMock,
) -> None:
redis_client, lock = _build_redis_mock_with_lock()
cc_pair = _build_cc_pair_mock()
db_session = MagicMock()
mock_get_redis_client.return_value = redis_client
mock_get_session.return_value.__enter__.return_value = db_session
mock_fetch_cc_pair_ids.return_value = [123]
mock_get_cc_pair.return_value = cc_pair
mock_supports_hierarchy_fetching.return_value = True
mock_is_due.return_value = True
mock_try_create_task.return_value = "task-id"
task_app = MagicMock()
with patch.object(check_for_hierarchy_fetching, "app", task_app):
result = check_for_hierarchy_fetching.run(tenant_id="test-tenant")
assert result == 1
mock_is_due.assert_called_once_with(cc_pair)
mock_try_create_task.assert_called_once_with(
celery_app=task_app,
cc_pair=cc_pair,
db_session=db_session,
r=redis_client,
tenant_id="test-tenant",
)
lock.release.assert_called_once()
@patch(f"{TASKS_MODULE}._try_creating_hierarchy_fetching_task")
@patch(f"{TASKS_MODULE}._is_hierarchy_fetching_due")
@patch(f"{TASKS_MODULE}.get_connector_credential_pair_from_id")
@patch(f"{TASKS_MODULE}.fetch_indexable_standard_connector_credential_pair_ids")
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
@patch(f"{TASKS_MODULE}.get_redis_client")
@patch(f"{TASKS_MODULE}._connector_supports_hierarchy_fetching")
def test_check_for_hierarchy_fetching_skips_supported_connector_when_not_due(
mock_supports_hierarchy_fetching: MagicMock,
mock_get_redis_client: MagicMock,
mock_get_session: MagicMock,
mock_fetch_cc_pair_ids: MagicMock,
mock_get_cc_pair: MagicMock,
mock_is_due: MagicMock,
mock_try_create_task: MagicMock,
) -> None:
redis_client, lock = _build_redis_mock_with_lock()
cc_pair = _build_cc_pair_mock()
mock_get_redis_client.return_value = redis_client
mock_get_session.return_value.__enter__.return_value = MagicMock()
mock_fetch_cc_pair_ids.return_value = [123]
mock_get_cc_pair.return_value = cc_pair
mock_supports_hierarchy_fetching.return_value = True
mock_is_due.return_value = False
task_app = MagicMock()
with patch.object(check_for_hierarchy_fetching, "app", task_app):
result = check_for_hierarchy_fetching.run(tenant_id="test-tenant")
assert result == 0
mock_is_due.assert_called_once_with(cc_pair)
mock_try_create_task.assert_not_called()
lock.release.assert_called_once()

View File

@@ -1,9 +1,13 @@
"""Tests for llm_loop.py, specifically the construct_message_history function."""
"""Tests for llm_loop.py, including history construction and empty-response paths."""
from unittest.mock import Mock
import pytest
from onyx.chat.llm_loop import _build_empty_llm_response_error
from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
@@ -13,6 +17,7 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
@@ -1167,3 +1172,57 @@ class TestFallbackToolExtraction:
assert result is llm_step_result
assert attempted is False
class TestEmptyLlmResponseClassification:
def _make_llm(self, provider: str = "openai", model: str = "gpt-5.2") -> Mock:
llm = Mock()
llm.config = LLMConfig(
model_provider=provider,
model_name=model,
temperature=0.0,
max_input_tokens=4096,
)
return llm
def test_openai_empty_stream_is_classified_as_budget_exceeded(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr("onyx.chat.llm_loop.is_true_openai_model", lambda *_: True)
err = _build_empty_llm_response_error(
llm=self._make_llm(),
llm_step_result=LlmStepResult(
reasoning=None,
answer=None,
tool_calls=None,
raw_answer=None,
),
tool_choice=ToolChoiceOptions.AUTO,
)
assert isinstance(err, EmptyLLMResponseError)
assert err.error_code == "BUDGET_EXCEEDED"
assert err.is_retryable is False
assert "quota" in err.client_error_msg.lower()
def test_reasoning_only_response_uses_generic_empty_response_error(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr("onyx.chat.llm_loop.is_true_openai_model", lambda *_: True)
err = _build_empty_llm_response_error(
llm=self._make_llm(),
llm_step_result=LlmStepResult(
reasoning="scratchpad only",
answer=None,
tool_calls=None,
raw_answer=None,
),
tool_choice=ToolChoiceOptions.AUTO,
)
assert isinstance(err, EmptyLLMResponseError)
assert err.error_code == "EMPTY_LLM_RESPONSE"
assert err.is_retryable is True
assert "quota" not in err.client_error_msg.lower()

View File

@@ -0,0 +1,34 @@
from onyx.chat.process_message import remove_answer_citations
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
answer = "The answer is Paris [[1]](https://example.com/doc)."
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_strips_empty_markdown_citation() -> None:
answer = "The answer is Paris [[1]]()."
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_strips_citation_with_parentheses_in_url() -> None:
answer = (
"The answer is Paris "
"[[1]](https://en.wikipedia.org/wiki/Function_(mathematics))."
)
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None:
answer = (
"See [reference](https://example.com/Function_(mathematics)) "
"for context [[1]](https://en.wikipedia.org/wiki/Function_(mathematics))."
)
assert (
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)

View File

@@ -3,7 +3,10 @@ from unittest.mock import Mock
import pytest
from onyx.chat import process_message
from onyx.chat.models import AnswerStream
from onyx.chat.models import StreamingError
from onyx.configs import app_configs
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -35,3 +38,26 @@ def test_mock_llm_response_requires_integration_mode() -> None:
db_session=Mock(),
)
)
def test_gather_stream_returns_empty_answer_when_streaming_error_only() -> None:
packets: AnswerStream = iter(
[
MessageResponseIDInfo(
user_message_id=None,
reserved_assistant_message_id=42,
),
StreamingError(
error="OpenAI quota exceeded",
error_code="BUDGET_EXCEEDED",
is_retryable=False,
),
]
)
result = process_message.gather_stream(packets)
assert result.answer == ""
assert result.answer_citationless == ""
assert result.error_msg == "OpenAI quota exceeded"
assert result.message_id == 42

View File

@@ -0,0 +1,63 @@
"""
Unit test verifying that the upload API path sends tasks with expires=.
The upload_files_to_user_files_with_indexing function must include expires=
on every send_task call to prevent phantom task accumulation if the worker
is down or slow.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.models import UserFile
from onyx.db.projects import upload_files_to_user_files_with_indexing
def _make_mock_user_file() -> MagicMock:
uf = MagicMock(spec=UserFile)
uf.id = str(uuid4())
return uf
@patch("onyx.db.projects.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.db.projects.create_user_files")
@patch(
"onyx.background.celery.versioned_apps.client.app",
new_callable=MagicMock,
)
def test_send_task_includes_expires(
mock_client_app: MagicMock,
mock_create: MagicMock,
mock_tenant: MagicMock, # noqa: ARG001
) -> None:
"""Every send_task call from the upload path must include expires=."""
user_files = [_make_mock_user_file(), _make_mock_user_file()]
mock_create.return_value = MagicMock(
user_files=user_files,
rejected_files=[],
id_to_temp_id={},
)
mock_user = MagicMock()
mock_db_session = MagicMock()
upload_files_to_user_files_with_indexing(
files=[],
project_id=None,
user=mock_user,
temp_id_map=None,
db_session=mock_db_session,
)
assert mock_client_app.send_task.call_count == len(user_files)
for call in mock_client_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), "send_task must include expires= to prevent phantom task accumulation"

View File

@@ -0,0 +1,45 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Parent 2 0 R
>>
endobj
xref
0 5
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
trailer
<<
/Size 5
/Root 3 0 R
/Info 1 0 R
>>
startxref
256
%%EOF

View File

@@ -0,0 +1,89 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 2
/Kids [ 4 0 R 6 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page one content) Tj ET
endstream
endobj
6 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 7 0 R
/Parent 2 0 R
>>
endobj
7 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page two content) Tj ET
endstream
endobj
xref
0 8
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000119 00000 n
0000000168 00000 n
0000000349 00000 n
0000000446 00000 n
0000000627 00000 n
trailer
<<
/Size 8
/Root 3 0 R
/Info 1 0 R
>>
startxref
724
%%EOF

View File

@@ -0,0 +1,62 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
BT /F1 12 Tf 50 150 Td (Hello World) Tj ET
endstream
endobj
xref
0 6
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
0000000343 00000 n
trailer
<<
/Size 6
/Root 3 0 R
/Info 1 0 R
>>
startxref
435
%%EOF

Some files were not shown because too many files have changed in this diff Show More