Compare commits

..

61 Commits

Author SHA1 Message Date
Nik
c8e565fa75 fix(chat): fix B1/B2/P1 bugs in multi-model streaming + cleanup
B1 — Self-completion race: the model finishes before GeneratorExit fires, so
the worker exits _run_model with drain_done=False and skips self-completion.
Fix: add completion_locks (one per model); the disconnect else-branch claims
the lock and calls llm_loop_completion_handle for already-succeeded models.

B2 — Stop-button saves wrong message for errored models: the stop loop
called llm_loop_completion_handle for all models including ones that
threw exceptions, persisting "stopped by user" for an errored model.
Fix: add model_errored flag; stop loop skips errored models.

P1 — Orphaned ChatMessage rows for errored models: reserved_messages
were never cleaned up when a model errored. Fix: delete via
db_session.get(ChatMessage, id) in all three exit paths (normal
completion, stop-button, disconnect).

Also: extract repeated orphan-cleanup into _delete_orphaned_message
nested helper; remove dead check_call_count variable in tests; rename
ctx→worker_context, _completion_done→completion_persisted; replace
functools.partial with captured-variable lambda; fix stale docstring
("bounded"→"unbounded"); add _CANCEL_POLL_INTERVAL_S named constant;
if/if/if→if/elif/elif; %-style logger calls throughout.

Tests: two new regression tests (B1 race, B2 stop-button errored model).
26 tests pass. mypy clean.
2026-04-01 08:17:50 -07:00
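A minimal sketch of the completion_locks idea behind the B1 fix; the helper name try_complete and the surrounding flags are illustrative, not the actual code. Both the worker's self-completion path and the disconnect else-branch call the same guarded helper, so whichever arrives first persists the result and the other becomes a no-op.

```
import threading

NUM_MODELS = 3
completion_locks = [threading.Lock() for _ in range(NUM_MODELS)]
completion_persisted = [False] * NUM_MODELS


def try_complete(model_idx: int, persist) -> None:
    """Persist one model's final response exactly once, whichever path arrives first."""
    with completion_locks[model_idx]:
        if completion_persisted[model_idx]:
            return  # the other path (worker self-completion vs. disconnect branch) won
        persist(model_idx)  # stand-in for llm_loop_completion_handle(...)
        completion_persisted[model_idx] = True
```

The worker would call this after run_llm_loop returns; the disconnect else-branch would call it for every succeeded model, which covers models that finished before GeneratorExit fired.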
Nik
bab95d8bf0 fix(chat): remove duplicate drain_done declaration after rebase 2026-03-31 20:02:29 -07:00
Nik
eb7bc74e1b fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old code called executor.shutdown(wait=False) with
no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly

Unit test updated to reflect the async self-completion semantics: the test
blocks the worker inside run_llm_loop until gen.close() sets drain_done,
then waits for completion_called inside the patch context (while mocks are
still active) to avoid calling the real get_session_with_current_tenant.
2026-03-31 20:02:29 -07:00
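A rough sketch of the three pieces above in one place; MODEL_DONE, db_session_for_worker, and persist_completion are stand-ins for the real sentinel, get_session_with_current_tenant, and llm_loop_completion_handle, and the wiring is illustrative rather than the exact implementation.

```
import queue
import threading
from contextlib import contextmanager

MODEL_DONE = object()  # hypothetical per-model completion sentinel


@contextmanager
def db_session_for_worker():
    """Stand-in for get_session_with_current_tenant()."""
    yield object()


def persist_completion(db_session, model_idx: int) -> None:
    """Stand-in for llm_loop_completion_handle(...)."""


def worker(model_idx: int, merged_queue: queue.Queue, drain_done: threading.Event) -> None:
    # ... run_llm_loop streams packets through an Emitter that returns early once
    # drain_done is set, so the worker never blocks on queue.put after a disconnect ...
    if drain_done.is_set():
        # Self-completion: the drain loop already exited, so nothing downstream will
        # persist this response; do it from the worker's own DB session.
        with db_session_for_worker() as db_session:
            persist_completion(db_session, model_idx)
    merged_queue.put((model_idx, MODEL_DONE))


def drain(merged_queue: queue.Queue, drain_done: threading.Event, num_models: int):
    remaining = num_models
    try:
        while remaining:
            model_idx, item = merged_queue.get()
            if item is MODEL_DONE:
                remaining -= 1
            else:
                yield model_idx, item
    except GeneratorExit:
        drain_done.set()  # emitters return immediately instead of blocking
        while True:  # one-time drain releases any put() already in progress
            try:
                merged_queue.get_nowait()
            except queue.Empty:
                break
        raise
```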
Nik
29da0aefb5 feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
Adds support for running 2-3 LLMs in parallel within a single chat turn,
with responses streamed to the frontend in interleaved arrival order via the merged queue
infrastructure introduced in the preceding PR.

Backend changes
- process_message.py: restore llm_overrides param on build_chat_turn and
  _stream_chat_turn; restore is_multi branching for LLM setup, context
  window sizing, and message ID reservation; add _build_model_display_name
  and handle_multi_model_stream (public multi-model entrypoint)
- db/chat.py: add reserve_multi_model_message_ids (reserves N assistant
  message placeholders sharing the same parent), set_preferred_response
  (marks one response as the user's preferred), and extend
  translate_db_message_to_chat_message_detail with preferred_response_id
  and model_display_name fields
- chat_backend.py: route requests with more than one llm_override through
  handle_multi_model_stream; reject non-streaming multi-model requests with
  OnyxError; add /set-preferred-response endpoint

Tests
- test_multi_model_streaming.py: unit tests for _run_models drain loop
  (arrival-order yield, error isolation, cancellation), handle_multi_model_stream
  validation guards, and N=1 backwards-compatibility
2026-03-31 20:01:21 -07:00
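A condensed sketch of how the N workers might feed one merged queue; the ThreadPoolExecutor wiring, MODEL_DONE sentinel, and error-packet shape here are assumptions for illustration, not the exact implementation.

```
import queue
from concurrent.futures import ThreadPoolExecutor

MODEL_DONE = object()  # hypothetical per-model completion sentinel


def run_models(model_fns, merged_queue: queue.Queue):
    """Yield (model_idx, item) pairs in arrival order across all models."""

    def _run_model(model_idx, model_fn):
        try:
            for packet in model_fn():  # each model streams its own packets
                merged_queue.put((model_idx, packet))
        except Exception as exc:
            # Error isolation: surface the failure for this model only; siblings keep going.
            merged_queue.put((model_idx, exc))
        finally:
            merged_queue.put((model_idx, MODEL_DONE))

    with ThreadPoolExecutor(max_workers=len(model_fns)) as executor:
        for idx, fn in enumerate(model_fns):
            executor.submit(_run_model, idx, fn)
        remaining = len(model_fns)
        while remaining:
            model_idx, item = merged_queue.get()
            if item is MODEL_DONE:
                remaining -= 1
            else:
                yield model_idx, item
```

With a single model function the loop degenerates to the old single-model stream, which is the N=1 backwards-compatibility the tests cover.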
Nik
6c86301c51 fix(chat): remove bounded queue and packet drops — match old behavior
Old code used queue.Queue() (unbounded, blocking put). New code introduced
queue.Queue(maxsize=100) + put(timeout=3.0) + silent drop on queue.Full —
a regression in all three callsites:

- Emitter.emit(): data packets silently dropped on queue full
- _run_model exception path: model errors silently lost
- _run_model finally (_MODEL_DONE): if dropped, drain loop hangs forever
  (models_remaining never reaches 0)

Fix: remove maxsize, remove all timeout= arguments, remove all
except queue.Full handlers. The drain_done early-return in emit() is the
correct disconnect mechanism; queue backpressure is not needed.

Also adds _completion_done: bool type annotation and fixes the queue drain
comment (no longer unblocking timed-out puts — just releasing memory).
2026-03-31 20:00:46 -07:00
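A tiny illustration of the regression being reverted here (not the actual call sites): with a bounded queue and a timed put, a full queue silently drops packets, and if the dropped item happens to be the per-model done sentinel the drain loop's counter never reaches zero.

```
import queue

bounded = queue.Queue(maxsize=100)
unbounded = queue.Queue()  # the restored behavior: blocking put, nothing dropped


def emit_with_drop(item) -> None:
    try:
        bounded.put(item, timeout=3.0)
    except queue.Full:
        pass  # data packet, model error, or _MODEL_DONE sentinel silently lost


def emit_blocking(item) -> None:
    unbounded.put(item)  # may block briefly, but every packet is delivered
```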
Nik
631146f48f fix(chat): use model_succeeded instead of check_is_connected on self-completion
On HTTP disconnect, check_is_connected() returns False, causing
llm_loop_completion_handle to treat a completed response as
user-cancelled and append "Generation was stopped by the user."
Use lambda: model_succeeded[model_idx] (always True here) instead,
matching the cancellation path's functools.partial(bool, model_succeeded[i]).
2026-03-31 18:42:04 -07:00
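A small sketch of the callback swap (surrounding names are illustrative): the completion handler takes an is_connected-style callable, and reporting the model's own outcome instead of the HTTP connection state keeps a finished response from being recorded as user-cancelled. Note the capture difference: functools.partial(bool, flag) snapshots the flag's value when it is built, while a lambda reads the current value at call time.

```
import functools

model_succeeded = [True, False]  # hypothetical per-model outcome flags
model_idx = 0


def check_is_connected() -> bool:
    """Stand-in for the request-level connection check; False after a disconnect."""
    return False


# Old: once the client is gone, a finished response looks user-cancelled.
is_connected = check_is_connected

# New: report this model's own outcome (always True on the self-completion path).
is_connected = lambda: model_succeeded[model_idx]

# Eager-capture variant used on the cancellation path.
is_connected = functools.partial(bool, model_succeeded[model_idx])
```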
Nik
f327278506 fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old else branch just called executor.shutdown(wait=False)
with no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly
2026-03-31 18:14:50 -07:00
Nik
c7cc439862 fix(emitter): address Greptile P1/P2/P3 and Queue typing
- P1: executor.shutdown(wait=False) on early exit — don't block the
  server thread waiting for LLM workers; they will hit queue.Full
  timeouts and exit on their own (matches old run_chat_loop behavior)
- P2: wrap db_session.commit() in try/finally in build_chat_turn —
  reset processing status before propagating if commit fails, so the
  chat session isn't stuck at "processing" permanently
- P3: fix inaccurate comment "All worker threads have exited" — workers
  may still be closing their own DB sessions at that point; clarify
  that only the main-thread db_session is safe to use
- Queue[Any] → Queue[tuple[int, Packet | Exception | object]] in Emitter
2026-03-31 17:02:46 -07:00
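A control-flow sketch of the P2 intent; the field name, status values, and the try/except shape here are hypothetical and not necessarily the exact try/finally used in build_chat_turn.

```
def mark_session_processing(db_session, chat_session) -> None:
    previous_status = chat_session.processing_status  # hypothetical field
    chat_session.processing_status = "PROCESSING"
    try:
        db_session.commit()
    except Exception:
        db_session.rollback()  # discard the failed transaction
        chat_session.processing_status = previous_status  # best-effort reset before propagating
        raise
```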
Nik
3365a369e2 fix(review): address Greptile comments
- Add owner to bare TODO comment
- Restore placement field assertions weakened by Emitter refactor
2026-03-31 12:49:09 -07:00
Nik
470bda3fb5 refactor(chat): elegance pass on PR1 changed files
process_message.py:
- Fix `skip_clarification` field in ChatTurnSetup: inline comment inside
  the type annotation → separate `#` comment on the line above the field
- Flatten `model_tools` via list comprehension instead of manual extend loop
- `forced_tool_id` membership test: list → set comprehension (O(1) lookup)
- Trim `_run_model` inner-function docstring — private closure doesn't need
  10-line Args block
- Remove redundant inline param comments from `_stream_chat_turn` and
  `handle_stream_message_objects` where the docstring Args section already
  documents them
- Strip duplicate Args/Returns from `handle_stream_message_objects` docstring
  — it delegates entirely to `_stream_chat_turn`

emitter.py:
- Widen `merged_queue` annotation to `Queue[Any]`: Queue is invariant so
  `Queue[tuple[int, Packet]]` can't be passed a `Queue[tuple[int, Packet |
  Exception | object]]`; the emitter is a write-only producer and doesn't
  care what else lives on the queue
2026-03-31 12:16:38 -07:00
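The invariance point is easy to trip over, so here is a minimal illustration (Packet is stubbed; only the typing behavior matters):

```
from queue import Queue


class Packet:
    ...


def emit_narrow(q: Queue[tuple[int, Packet]]) -> None:
    q.put((0, Packet()))


wide_q: Queue[tuple[int, Packet | Exception | object]] = Queue()

# mypy rejects this call: Queue is invariant, so a queue of the wider tuple type is
# not a subtype of Queue[tuple[int, Packet]], even though every (int, Packet) value
# could legally be put on it. Widening the producer-side annotation (e.g. Queue[Any])
# fixes it; the emitter is write-only and never reads elements back.
emit_narrow(wide_q)
```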
Nik
13f511e209 refactor(emitter): clean up string annotation and use model_copy
- Fix `"Queue"` forward-reference annotation → `Queue[tuple[int, Packet]]`
  (Queue is already imported, the string was unnecessary)
- Replace manual Placement field copy with `base.model_copy(update={...})`
- Remove redundant `key` variable (was just `self._model_idx`)
- Tighten docstring
2026-03-31 11:44:28 -07:00
Nik
c5e8ba1eab refactor(chat): replace bus-polling emitter with merged-queue streaming; fix 429 hang
Switch Emitter from a per-model event bus + polling thread to a single
bounded queue shared across all models.  Each emit() call puts directly onto
the queue; the drain loop in _run_models yields packets in arrival order.

Key changes
- emitter.py: remove Bus, get_default_emitter(); add Emitter(merged_queue, model_idx)
- chat_state.py: remove run_chat_loop_with_state_containers (113-line bus-poll loop)
- process_message.py: add ChatTurnSetup dataclass and build_chat_turn(); rewrite
  _stream_chat_turn + _run_models around the merged queue; single-model (N=1)
  path is fully backwards-compatible
- placement.py, override_models.py: add docstrings; LLMOverride gains display_name
- research_agent.py, custom_tool.py: update Emitter call sites
- test_emitter.py: new unit tests for queue routing, model_index tagging, placement

Frontend 429 fix
- lib.tsx: parse response body for human-readable detail on non-2xx responses
  instead of "HTTP error! status: 429"
- useChatController.ts: surface stack.error after the FIFO drain loop exits so
  the catch block replaces the thinking placeholder with an error message
2026-03-30 22:18:48 -07:00
dependabot[bot]
0b9d154a73 chore(deps): bump runs-on/cache from 50350ad4242587b6c8c2baa2e740b1bc11285ff4 to a5f51d6f3fece787d03b7b4e981c82538a0654ed (#9763)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:54:43 -07:00
dependabot[bot]
6e65e55bf5 chore(deps): bump actions/cache from 5.0.3 to 5.0.4 (#9765)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:46:53 -07:00
Raunak Bhagat
3f9e208759 feat(opal): SelectCard + CardHeaderLayout (#9760) 2026-03-30 19:54:54 +00:00
dependabot[bot]
fb8edda14a chore(deps): bump pygments from 2.19.2 to 2.20.0 (#9757)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-30 18:30:18 +00:00
Jamison Lahman
58decd8a6b chore(gha): prefer ci-protected env (#9728) 2026-03-30 17:20:54 +00:00
Danelegend
e97204d9cc feat(indexing): Batch chunks during doc processing (#9468) 2026-03-30 11:49:36 +00:00
Danelegend
44ab02c94f refactor(indexing): Refactor indexing vector db abstraction (#9653) 2026-03-30 09:57:16 +00:00
Danelegend
a98cc30f25 refactor(indexing): Change adapters to support iterables (#9469) 2026-03-30 01:43:10 +00:00
Danelegend
a709dcb8fa feat(indexing): Max chunk processing (#9400) 2026-03-30 00:10:24 +00:00
Raunak Bhagat
a3dfe6aa1b refactor(opal): unify Interactive color system (#9717) 2026-03-28 00:40:23 +00:00
Nikolas Garza
23e4d55fb1 perf(swr): convert raw-fetch hooks to SWR to eliminate duplicate requests (#9694) 2026-03-28 00:26:20 +00:00
Jamison Lahman
470cc85f83 feat(cli): onyx-cli serve over SSH (#9726) 2026-03-27 23:46:14 +00:00
Justin Tahara
64d9be5a41 fix(openpyxl): Colors must be aRGB hex values (#9727) 2026-03-27 23:14:36 +00:00
roshan
71a5b469b0 feat(widget): add citation badges to chat widget (#9714) 2026-03-27 22:39:46 +00:00
Evan Lohn
462eb0697f fix: Anthropic litellm thinking workaround (#9713) 2026-03-27 21:03:05 +00:00
dependabot[bot]
b708dc8796 chore(deps): bump langchain-core from 1.2.11 to 1.2.22 (#9720)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:50:19 +00:00
dependabot[bot]
c9e2c32f55 chore(deps): bump cryptography from 46.0.5 to 46.0.6 (#9721)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:48:59 +00:00
Jamison Lahman
d725df62e7 feat(cli): --version and validate-config warn if backend version is incompatible (#9715) 2026-03-27 13:13:16 -07:00
Jamison Lahman
d1460972b6 fix(cli): onyx-cli --version interpolation (#9712) 2026-03-27 19:22:31 +00:00
Jamison Lahman
706872f0b7 chore(deps): upgrade go deps (#9711) 2026-03-27 12:24:25 -07:00
Jamison Lahman
ed3856be2b chore(release): build all CLI wheels before publishing (#9710) 2026-03-27 19:04:02 +00:00
Jamison Lahman
6326c7f0b9 chore(gha): fix git error after helm release migration to alpine base image (#9709) 2026-03-27 11:21:34 -07:00
Jamison Lahman
40420fc4e6 chore(gha): helm release upstream nits (#9708) 2026-03-27 11:10:41 -07:00
Nikolas Garza
1a2b6a66cc fix(celery): use broker connection pool to prevent Redis connection leak (#9682) 2026-03-27 17:53:49 +00:00
Jamison Lahman
d1b1529ccf chore(gha): fix helm release after image update (#9707) 2026-03-27 10:37:43 -07:00
Bo-Onyx
fedd9c76e5 feat(hook): admin page create or edit hook (#9690) 2026-03-27 17:10:14 +00:00
Jamison Lahman
0b34b40b79 chore(gha): pin helm release docker image (#9706) 2026-03-27 10:16:41 -07:00
Yuhong Sun
fe82ddb1b9 Update README.md (#9703) 2026-03-27 10:03:56 -07:00
Jamison Lahman
32d3d70525 chore(playwright): deflake settings_pages.spec.ts (#9684) 2026-03-27 15:54:23 +00:00
Jamison Lahman
40b9e10890 chore(devtools): upgrade ods: 0.7.1->0.7.2 (#9701) 2026-03-27 08:17:42 -07:00
dependabot[bot]
e21b204b8a chore(deps): bump brace-expansion in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9698)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-27 08:10:15 -07:00
Jamison Lahman
2f672b3a4f fix(fe): Popover content doesn't overflow on small screens (#9612) 2026-03-27 08:07:52 -07:00
Nikolas Garza
cf19d0df4f feat(helm): add Prometheus metrics ports and Services for celery workers (#9630) 2026-03-27 08:03:48 +00:00
Danelegend
86a6a4c134 refactor(indexing): Vespa & Opensearch index function use Iterable (#9384) 2026-03-27 04:36:59 +00:00
SubashMohan
146b5449d2 feat: configurable file upload size and token limits via admin settings (#9232) 2026-03-27 04:23:16 +00:00
Jamison Lahman
b66991b5c5 chore(devtools): ods trace (#9688)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-27 03:56:38 +00:00
dependabot[bot]
9cb76dc027 chore(deps-dev): bump picomatch from 2.3.1 to 2.3.2 in /web (#9691)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 02:22:22 +00:00
dependabot[bot]
f66891d19e chore(deps-dev): bump handlebars from 4.7.8 to 4.7.9 in /web (#9689)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 01:41:29 +00:00
Nikolas Garza
c07c952ad5 chore(greptile): add nginx routing rule for non-api backend routes (#9687) 2026-03-27 00:34:15 +00:00
Nikolas Garza
be7f40a28a fix(nginx): route /scim/* to api_server (#9686) 2026-03-26 17:21:57 -07:00
Evan Lohn
26f941b5da perf: perm sync start time (#9685) 2026-03-27 00:07:53 +00:00
Jamison Lahman
b9e84c42a8 feat(providers): allow deleting all types of providers (#9625)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-26 15:20:56 -07:00
Bo-Onyx
0a1df52c2f feat(hook): Hook Form Modal Polish. (#9683) 2026-03-26 22:12:33 +00:00
Nikolas Garza
306b0d452f fix(billing): retry claimLicense up to 3x after Stripe checkout return (#9669) 2026-03-26 21:06:19 +00:00
Justin Tahara
5fdb34ba8e feat(llm): add Bifrost gateway frontend modal and provider registration (#9617) 2026-03-26 20:50:25 +00:00
Jamison Lahman
2d066631e3 fix(voice): don't soft-delete providers (#9679) 2026-03-26 19:26:32 +00:00
Evan Lohn
5c84f6c61b fix(jira): large batches fail json decode (#9677) 2026-03-26 18:53:37 +00:00
Nikolas Garza
899179d4b6 fix(api-key): clarify upgrade message for trial accounts (#9678) 2026-03-26 18:32:41 +00:00
Bo-Onyx
80d6bafc74 feat(hook): Hook connect/manage modal (#9645) 2026-03-26 18:16:33 +00:00
447 changed files with 13671 additions and 3085 deletions

View File

@@ -47,7 +47,8 @@ jobs:
done
- name: Publish Helm charts to gh-pages
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts

View File

@@ -35,6 +35,7 @@ jobs:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 5
steps:
- name: Checkout

View File

@@ -183,6 +183,7 @@ jobs:
- cherry-pick-to-latest-release
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout
@@ -232,6 +233,7 @@ jobs:
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout

View File

@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/

View File

@@ -284,7 +284,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
@@ -626,7 +626,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}

View File

@@ -31,6 +31,7 @@ jobs:
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
environment: ci-protected
timeout-minutes: 45
env:

View File

@@ -15,6 +15,7 @@ permissions:
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd

View File

@@ -13,15 +13,6 @@ jobs:
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
@@ -31,9 +22,11 @@ jobs:
enable-cache: false
version: "0.9.9"
- run: |
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
for goos in linux windows darwin; do
for goarch in amd64 arm64; do
GOOS="$goos" GOARCH="$goarch" uv build --wheel
done
done
working-directory: cli
- run: uv publish
working-directory: cli

View File

@@ -25,6 +25,7 @@ permissions:
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
@@ -54,6 +55,7 @@ jobs:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

View File

@@ -9,6 +9,7 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 45
permissions:
contents: read

View File

@@ -11,6 +11,7 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 45
steps:

View File

@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -122,7 +122,7 @@ repos:
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
hooks:
- id: golangci-lint
language_version: "1.26.0"
language_version: "1.26.1"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit

View File

@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> ```
****

View File

@@ -0,0 +1,35 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_permission_sync_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
@@ -473,6 +474,8 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -105,9 +106,11 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
callback=callback,
start=indexing_start,
)
for doc_metadata_batch in slim_doc_generator:
@@ -180,9 +183,15 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector,
slack_connector=slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -40,10 +41,19 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -1,5 +1,6 @@
# These are helper objects for tracking the keys we need to write in redis
import json
import threading
from typing import Any
from typing import cast
@@ -7,7 +8,59 @@ from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
_broker_client: Redis | None = None
_broker_url: str | None = None
_broker_client_lock = threading.Lock()
def celery_get_broker_client(app: Celery) -> Redis:
"""Return a shared Redis client connected to the Celery broker DB.
Uses a module-level singleton so all tasks on a worker share one
connection instead of creating a new one per call. The client
connects directly to the broker Redis DB (parsed from the broker URL).
Thread-safe via lock — safe for use in Celery thread-pool workers.
Usage:
r_celery = celery_get_broker_client(self.app)
length = celery_get_queue_length(queue, r_celery)
"""
global _broker_client, _broker_url
with _broker_client_lock:
url = app.conf.broker_url
if _broker_client is not None and _broker_url == url:
try:
_broker_client.ping()
return _broker_client
except Exception:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
elif _broker_client is not None:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
_broker_url = url
_broker_client = Redis.from_url(
url,
decode_responses=False,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
retry_on_timeout=True,
)
return _broker_client
def celery_get_unacked_length(r: Redis) -> int:

View File

@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
r_celery = celery_get_broker_client(self.app)
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -449,7 +450,7 @@ def check_indexing_completion(
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = task.app.broker_connection().channel().client # type: ignore
redis_celery = celery_get_broker_client(task.app)
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,

View File

@@ -1,6 +1,5 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
@@ -19,6 +18,7 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect queue metrics with broker connection
r_celery = celery_get_broker_client(self.app)
queue_metrics = _collect_queue_metrics(r_celery)
# Collect and log each metric
# Collect remaining metrics (no broker connection needed)
with get_session_with_current_tenant() as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
all_metrics: list[Metric] = queue_metrics
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
for metric in all_metrics:
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(task.app)
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
num_deleted = 0
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")

View File

@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
redis_celery = celery_get_broker_client(celery_app)
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(

View File

@@ -1,19 +1,8 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -159,114 +148,3 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,19 +1,40 @@
import threading
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
def __init__(self, bus: Queue):
self.bus = bus
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
drain_done: Optional event set by ``_run_models`` when the drain loop
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
immediately so worker threads can exit fast.
"""
def __init__(
self,
merged_queue: Queue[tuple[int, Packet | Exception | object]],
model_idx: int = 0,
drain_done: threading.Event | None = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
self._drain_done = drain_done
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
if self._drain_done is not None and self._drain_done.is_set():
return
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
self._merged_queue.put((self._model_idx, tagged))

File diff suppressed because it is too large

View File

@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -791,6 +805,10 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,7 +950,9 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(_get_page_id(page))
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -10,6 +10,7 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -239,29 +240,53 @@ def enhanced_search_ids(
)
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise e
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in response["issues"]
for issue in raw_issues
]

View File

@@ -1765,7 +1765,11 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1786,7 +1790,9 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor
site_descriptor=site_descriptor,
start=start,
end=end,
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1841,7 +1847,9 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(site_descriptor)
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2565,12 +2573,22 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
yield from self._fetch_slim_documents_from_sharepoint()
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
if __name__ == "__main__":

View File

@@ -516,6 +516,8 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -847,8 +851,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -861,6 +865,8 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -617,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
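As a rough usage sketch (caller-side names like db_session, chat_session, and user_msg are placeholders, not taken from this diff), the multi-model turn setup would reserve one placeholder per model before streaming:
# Hypothetical caller sketch: reserve one assistant placeholder per model.
display_names = ["GPT-4o", "Claude 3.5 Sonnet"]
reserved = reserve_multi_model_message_ids(
    db_session=db_session,
    chat_session_id=chat_session.id,
    parent_message_id=user_msg.id,
    model_display_names=display_names,
)
# Each reserved[i].id is the row a worker later finalizes; the parent's
# latest_child_message_id now points at reserved[-1].id.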
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
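A minimal sketch of the route-handler side once the user picks a response, assuming a FastAPI-style endpoint (the request field names are invented for illustration):
from fastapi import HTTPException
try:
    set_preferred_response(
        db_session=db_session,
        user_message_id=request.user_message_id,
        preferred_assistant_message_id=request.preferred_message_id,
    )
except ValueError as e:
    # e.g. the assistant message is not a direct child of the user message
    raise HTTPException(status_code=400, detail=str(e))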
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int, include_deleted: bool = False
db_session: Session, provider_id: int
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
)
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
)
@@ -119,10 +108,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Soft-delete a voice provider by ID."""
"""Delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
provider.deleted = True
db_session.delete(provider)
db_session.flush()

View File

@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,11 +1,12 @@
import json
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -351,7 +352,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -647,10 +648,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -673,29 +674,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
]
onyx_document: Document = chunks[0].source_document
onyx_document: Document = doc_chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -704,22 +710,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
document_indexing_results.append(document_insertion_record)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
return document_indexing_results
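The buffering above is the generic "group a pre-ordered stream into contiguous runs" pattern; a stripped-down standalone sketch (ignoring the extra MAX_CHUNKS_PER_DOC_BATCH flush, and not the OpenSearch code itself) looks like:
from collections.abc import Callable, Iterable, Iterator
from itertools import groupby
from typing import TypeVar
T = TypeVar("T")
def contiguous_runs(items: Iterable[T], key: Callable[[T], object]) -> Iterator[list[T]]:
    # One list per run of equal keys; relies on the caller emitting all items
    # for a given key contiguously, exactly like the chunk iterable here.
    for _, run in groupby(items, key=key):
        yield list(run)
# chunks arriving as doc A, A, B, B, B, C -> [A chunks], [B chunks], [C chunks]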

View File

@@ -6,6 +6,7 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,6 +1,8 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -8,6 +10,7 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
existing_docs: set[str] = set()
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -44,6 +44,7 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]

View File

@@ -19,7 +19,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -85,14 +86,21 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
tenant_id: str,
chunks: list[DocAwareChunk],
) -> "DocumentChunkEnricher":
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
no_access = DocumentAccess.build(
user_emails=[],
@@ -102,67 +110,30 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
def _get_ancestor_ids_for_documents(
@@ -203,7 +174,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -227,7 +198,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -249,3 +220,52 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -24,11 +27,13 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -101,13 +106,20 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
no_access = DocumentAccess.build(
user_emails=[],
@@ -117,7 +129,6 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -138,17 +149,6 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -163,46 +163,30 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -246,8 +230,9 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -263,8 +248,10 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
str(user_file.id)
]
@@ -276,8 +263,54 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -0,0 +1,89 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from onyx.indexing.models import IndexChunk
class ChunkBatchStore:
"""Manages serialization of embedded chunks to a temporary directory.
Owns the temp directory lifetime and provides save/load/stream/scrub
operations.
Use as a context manager to ensure cleanup::
with ChunkBatchStore() as store:
store.save(chunks, batch_idx=0)
for chunk in store.stream():
...
"""
_EXT = ".pkl"
def __init__(self) -> None:
self._tmpdir: Path | None = None
# -- context manager -----------------------------------------------------
def __enter__(self) -> "ChunkBatchStore":
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
return self
def __exit__(self, *_exc: object) -> None:
if self._tmpdir is not None:
shutil.rmtree(self._tmpdir, ignore_errors=True)
self._tmpdir = None
@property
def _dir(self) -> Path:
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
return self._tmpdir
# -- storage primitives --------------------------------------------------
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
"""Serialize a batch of embedded chunks to disk."""
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
pickle.dump(chunks, f)
def _load(self, batch_file: Path) -> list[IndexChunk]:
"""Deserialize a batch of embedded chunks from a file."""
with open(batch_file, "rb") as f:
return pickle.load(f)
def _batch_files(self) -> list[Path]:
"""Return batch files sorted by numeric index."""
return sorted(
self._dir.glob(f"batch_*{self._EXT}"),
key=lambda p: int(p.stem.removeprefix("batch_")),
)
# -- higher-level operations ---------------------------------------------
def stream(self) -> Iterator[IndexChunk]:
"""Yield all chunks across all batch files.
Each call returns a fresh generator, so the data can be iterated
multiple times (e.g. once per document index).
"""
for batch_file in self._batch_files():
yield from self._load(batch_file)
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
When a document fails embedding in batch N, earlier batches may
already contain successfully embedded chunks for that document.
This ensures the output is all-or-nothing per document.
"""
for batch_file in self._batch_files():
batch_chunks = self._load(batch_file)
cleaned = [
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
]
if len(cleaned) != len(batch_chunks):
with open(batch_file, "wb") as f:
pickle.dump(cleaned, f)
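A small sketch of the failure-handling flow the scrub method supports (the batch variables are invented for illustration):
# Hypothetical flow: doc "b" embeds fine in batch 0 but fails in batch 1, so
# its batch-0 chunks must be removed to keep output all-or-nothing per doc.
with ChunkBatchStore() as store:
    store.save(batch_0_chunks, batch_idx=0)  # chunks for docs "a" and "b"
    store.save(batch_1_chunks, batch_idx=1)  # doc "b" failed here, only "a" saved
    store.scrub_failed_docs({"b"})
    assert all(c.source_document.id != "b" for c in store.stream())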

View File

@@ -1,5 +1,8 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from pydantic import BaseModel
@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class ChunkEmbeddingResult(BaseModel):
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
connector_failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
)
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
"""Extract document IDs from a list of connector failures."""
return {f.failed_document.document_id for f in failures if f.failed_document}
def _embed_chunks_to_store(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
"""Embed chunks in batches, spilling each batch to *store*.
If a document fails embedding in any batch, its chunks are excluded from
all batches (including earlier ones already written) so that the output
is all-or-nothing per document.
"""
successful_chunk_ids: list[tuple[int, str]] = []
all_embedding_failures: list[ConnectorFailure] = []
# Track failed doc IDs across all batches so that a failure in batch N
# causes chunks for that doc to be skipped in batch N+1 and stripped
# from earlier batches.
all_failed_doc_ids: set[str] = set()
for batch_idx, chunk_batch in enumerate(
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
):
# Skip chunks belonging to documents that failed in earlier batches.
chunk_batch = [
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
]
if not chunk_batch:
continue
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
chunks=chunk_batch,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
all_embedding_failures.extend(embedding_failures)
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
# Only keep successfully embedded chunks for non-failed docs.
chunks_with_embeddings = [
c
for c in chunks_with_embeddings
if c.source_document.id not in all_failed_doc_ids
]
successful_chunk_ids.extend(
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
)
store.save(chunks_with_embeddings, batch_idx)
del chunks_with_embeddings
# Scrub earlier batches for docs that failed in later batches.
if all_failed_doc_ids:
store.scrub_failed_docs(all_failed_doc_ids)
successful_chunk_ids = [
(chunk_id, doc_id)
for chunk_id, doc_id in successful_chunk_ids
if doc_id not in all_failed_doc_ids
]
return ChunkEmbeddingResult(
successful_chunk_ids=successful_chunk_ids,
connector_failures=all_embedding_failures,
)
@contextmanager
def embed_and_stream(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
"""Embed chunks to disk and yield a ``(result, store)`` pair.
The store owns the temp directory — files are cleaned up when the context
manager exits.
Usage::
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
for chunk in store.stream():
...
"""
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
store=store,
)
yield result, store
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -637,6 +762,29 @@ def add_contextual_summaries(
return chunks
def _verify_indexing_completeness(
insertion_records: list[DocumentInsertionRecord],
write_failures: list[ConnectorFailure],
embedding_failed_doc_ids: set[str],
updatable_ids: list[str],
document_index_name: str,
) -> None:
"""Verify that every updatable document was either indexed or reported as failed."""
all_returned_doc_ids = (
{r.document_id for r in insertion_records}
| {f.failed_document.document_id for f in write_failures if f.failed_document}
| embedding_failed_doc_ids
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
f"This should never happen. "
f"This occured for document index {document_index_name}"
)
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -672,12 +820,7 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
return IndexingPipelineResult.empty(len(filtered_documents))
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -716,117 +859,99 @@ def index_doc_batch(
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
if chunks
else ([], [])
)
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
embedding_result,
chunk_store,
):
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk_id,
document_id=document_id,
boost_score=1.0,
)
for chunk_id, document_id in embedding_result.successful_chunk_ids
]
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with adapter.lock_context(context.updatable_docs):
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
context=context,
embedding_failed_doc_ids = _get_failed_doc_ids(
embedding_result.connector_failures
)
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# Filter to only successfully embedded chunks so
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
embedded_chunks = [
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
]
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
# Acquires a lock on the documents so that no other process can modify
# them. Not needed until here, since this is when the actual race
# condition with vector db can occur.
with adapter.lock_context(context.updatable_docs):
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=embedded_chunks,
)
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
None
)
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
None
)
for document_index in document_indices:
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
for chunk in chunk_store.stream():
yield enricher.enrich_chunk(chunk, 1.0)
insertion_records, write_failures = (
write_chunks_to_vector_db_with_backoff(
document_index=document_index,
make_chunks=_enriched_stream,
index_batch_params=index_batch_params,
)
)
_verify_indexing_completeness(
insertion_records=insertion_records,
write_failures=write_failures,
embedding_failed_doc_ids=embedding_failed_doc_ids,
updatable_ids=updatable_ids,
document_index_name=document_index.__class__.__name__,
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
new_docs=sum(
1 for r in primary_doc_idx_insertion_records if not r.already_existed
),
total_docs=len(filtered_documents),
total_chunks=len(chunks_with_embeddings),
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
total_chunks=len(embedding_result.successful_chunk_ids),
failures=primary_doc_idx_vector_db_write_failures
+ embedding_result.connector_failures,
)

View File

@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
class IndexingBatchAdapter(Protocol):
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None: ...

View File

@@ -1,6 +1,9 @@
import time
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time, not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present.
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous per-document batches.
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
chunks=make_chunks(),
index_batch_params=index_batch_params,
)
),
@@ -60,14 +63,23 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
seen_doc_ids: set[str] = set()
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
if doc_id in seen_doc_ids:
raise RuntimeError(
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
)
seen_doc_ids.add(doc_id)
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
try:
insertion_records.extend(
document_index.index(
@@ -87,9 +99,7 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
document_link=first_chunk.get_link(),
),
failure_message=str(e),
exception=e,

View File

@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
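For illustration (UserMessage and the tool-call object are placeholders; only AssistantMessage is imported above), the check behaves like:
# Once any assistant turn in history carries tool_calls, thinking is skipped.
plain_history = [UserMessage(content="hi"), AssistantMessage(content="hello")]
tool_history = [
    UserMessage(content="search for X"),
    AssistantMessage(content=None, tool_calls=[some_prior_tool_call]),
]
assert not _prompt_contains_tool_call_history(plain_history)
assert _prompt_contains_tool_call_history(tool_history)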
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
reasoning_effort
)
if budget_tokens is not None:
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
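For a multi-model comparison turn, the caller would build one override per model; a hedged sketch (values invented, and display_name assumed to exist as the Attributes section describes):
overrides = [
    LLMOverride(model_provider="openai", model_version="gpt-4o"),
    LLMOverride(
        model_provider="anthropic",
        model_version="claude-3-5-sonnet",
        temperature=0.2,
    ),
]
# The multi-model entrypoint would receive these as its per-model overrides,
# one per side-by-side column in the UI.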

View File

@@ -175,6 +175,32 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000
def count_tokens(
text: str,
tokenizer: BaseTokenizer,
token_limit: int | None = None,
) -> int:
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
If token_limit is provided and the text is large enough to require
multiple chunks (> 500k chars), stops early once the count exceeds it.
When early-exiting, the returned value exceeds token_limit but may be
less than the true full token count.
"""
if len(text) <= _ENCODE_CHUNK_SIZE:
return len(tokenizer.encode(text))
total = 0
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
if token_limit is not None and total > token_limit:
return total # Already over — skip remaining chunks
return total
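A quick usage sketch for the early-exit path (the text and limit are made up; the get_tokenizer keywords match their use elsewhere in this changeset):
tokenizer = get_tokenizer(model_name="gpt-4o", provider_type="openai")
big_text = "word " * 2_000_000  # ~10M chars, forces multiple encode() calls
# Stops encoding once the running total passes the limit, so the result is
# known to be over the limit but is not the exact full count.
approx = count_tokens(big_text, tokenizer, token_limit=100_000)
assert approx > 100_000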
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -3844,9 +3844,9 @@
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"version": "5.0.5",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
@@ -4224,9 +4224,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -5007,9 +5007,9 @@
}
},
"node_modules/brace-expansion": {
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
# ---------------------------------------------------------------------------
@@ -122,9 +123,8 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
@@ -141,19 +141,11 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
"Hook endpoint validation: timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -9,20 +9,15 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
)
)
continue
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
)
continue
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
)
continue
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:
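A worked sketch of the limit derivation above (the helper and values are local to this sketch): 0 for the token threshold means "no limit" and becomes None, while any positive K becomes K*1000 tokens; the upload cap is always positive after load_settings().

def _derive_limits(
    token_threshold_k: int | None, max_upload_size_mb: int
) -> tuple[int | None, int]:
    # Mirrors the conversion in categorize_uploaded_files: 0 → None disables
    # the token check entirely; MB → bytes for the hard size cap.
    token_threshold = token_threshold_k * 1000 if token_threshold_k else None
    max_upload_size_bytes = max_upload_size_mb * 1024 * 1024
    return token_threshold, max_upload_size_bytes

assert _derive_limits(0, 20) == (None, 20 * 1024 * 1024)       # admin chose "no limit"
assert _derive_limits(200, 20) == (200_000, 20 * 1024 * 1024)  # vector-DB default (200K)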

View File

@@ -1524,6 +1524,7 @@ def get_bifrost_available_models(
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:

View File

@@ -463,3 +463,4 @@ class BifrostFinalModelResponse(BaseModel):
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
"""Reads Celery queue lengths from the broker Redis on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
self._celery_app: Any | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
if self._celery_app is None:
return []
redis_client = self._get_redis()
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
depth = GaugeMetricFamily(
"onyx_queue_depth",
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
self._celery_app: Any | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
if self._celery_app is None:
return []
redis_client = self._get_redis()
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",

View File

@@ -3,12 +3,8 @@
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# state) until configure() / set_celery_app() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a fresh
celery_app: The Celery application instance. Used to obtain a
broker Redis client on each scrape for queue depth metrics.
"""
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
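For reference, the multi-model path above only activates when llm_overrides carries 2-3 entries and stream=True; a hedged sketch of the request a client would build (field names are taken from SendMessageRequest/LLMOverride in this PR; whether stream defaults to True is not shown here, so it is set explicitly):

from uuid import uuid4

from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest

req = SendMessageRequest(
    message="Compare your answers",
    chat_session_id=uuid4(),  # an existing session's UUID in practice
    stream=True,              # multi-model is streaming-only; otherwise INVALID_INPUT
    llm_overrides=[
        LLMOverride(model_provider="openai", model_version="gpt-4", display_name="GPT-4"),
        LLMOverride(model_provider="anthropic", model_version="claude-3", display_name="Claude 3"),
    ],
)
# Each streamed line is a JSON packet whose placement.model_index (0 or 1 here)
# tells the frontend which response column it belongs to.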
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,11 +2,25 @@ from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to. ``0`` for single-model
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
for pre-LLM setup packets (e.g. message ID info) that are yielded
before any Emitter runs.
"""
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools; nested agents aren't supported yet but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None
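A quick sketch of these coordinates in practice (values are illustrative): the second parallel tool call of the first reasoning turn, rendered in the third model's column.

from onyx.server.query_and_chat.placement import Placement

placement = Placement(
    turn_index=0,         # first iterative block of this message
    tab_index=1,          # second parallel tool call within that turn
    sub_turn_index=None,  # top-level packet, not tool-within-tool output
    model_index=2,        # third model in a multi-model comparison
)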

View File

@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -17,10 +19,16 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -83,6 +100,16 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,12 +2,19 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -78,7 +85,12 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = None
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
# Connector settings
show_extra_connectors: bool | None = True
@@ -108,3 +120,14 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,13 +1,19 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings
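A small sketch of the resolution rule above for the token threshold (the helper is local to this sketch; the 200/10000 defaults mirror the constants imported at the top of this file):

def _resolve_token_threshold_k(stored: int | None, vector_db_disabled: bool) -> int:
    # None → admin never set it → context-aware default; 0 → explicit "no limit".
    if stored is None:
        return 10000 if vector_db_disabled else 200
    return stored

assert _resolve_token_threshold_k(None, vector_db_disabled=False) == 200    # vector-DB default
assert _resolve_token_threshold_k(None, vector_db_disabled=True) == 10000   # no-vector-DB default
assert _resolve_token_threshold_k(0, vector_db_disabled=False) == 0         # "no limit" preserved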

View File

@@ -1,3 +1,4 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -708,7 +709,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter_queue.qsize()}")

View File

@@ -1,5 +1,6 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -11,7 +12,6 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use default emitter if none provided
# Use a discard emitter if none provided (packets go nowhere)
if emitter is None:
emitter = get_default_emitter()
emitter = Emitter(merged_queue=queue.Queue())
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
dynamic_schema_info=None,
)

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.5
cryptography==46.0.6
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.11
langchain-core==1.2.22
# via onyx
langdetect==1.0.9
# via unstructured
@@ -735,7 +735,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
# via onyx
pygments==2.19.2
pygments==2.20.0
# via rich
pyjwt==2.12.0
# via

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt
@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.1
onyx-devtools==0.7.2
# via onyx
openai==2.14.0
# via
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.19.2
pygments==2.20.0
# via
# ipython
# ipython-pygments-lexers

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,4 +1,6 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=start_ts,
end=time.time(),
)
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=None,
end=time.time(),
)

View File

@@ -27,11 +27,13 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
Also patches ``celery_get_broker_client`` so the mock app doesn't need
a real broker URL.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
class TestAdapterWritesBothMetadataFields:
"""build_metadata_aware_chunks must populate user_project AND personas."""
"""prepare_enrichment must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -13,6 +13,7 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -20,7 +21,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -16,7 +17,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)

View File

@@ -0,0 +1,87 @@
"""Tests for celery_get_broker_client singleton."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.background.celery import celery_redis
@pytest.fixture(autouse=True)
def reset_singleton() -> Iterator[None]:
"""Reset the module-level singleton between tests."""
celery_redis._broker_client = None
celery_redis._broker_url = None
yield
celery_redis._broker_client = None
celery_redis._broker_url = None
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
app = MagicMock()
app.conf.broker_url = broker_url
return app
class TestCeleryGetBrokerClient:
@patch("onyx.background.celery.celery_redis.Redis")
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
result = celery_redis.celery_get_broker_client(app)
assert result is mock_client
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://localhost:6379/15"
assert call_args[1]["decode_responses"] is False
assert call_args[1]["socket_keepalive"] is True
assert call_args[1]["retry_on_timeout"] is True
@patch("onyx.background.celery.celery_redis.Redis")
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
client1 = celery_redis.celery_get_broker_client(app)
client2 = celery_redis.celery_get_broker_client(app)
assert client1 is client2
# from_url called only once
assert mock_redis_cls.from_url.call_count == 1
@patch("onyx.background.celery.celery_redis.Redis")
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
stale_client = MagicMock()
stale_client.ping.side_effect = ConnectionError("disconnected")
fresh_client = MagicMock()
fresh_client.ping.return_value = True
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
app = _make_mock_app()
# First call creates stale_client
client1 = celery_redis.celery_get_broker_client(app)
assert client1 is stale_client
# Second call: ping fails, creates fresh_client
client2 = celery_redis.celery_get_broker_client(app)
assert client2 is fresh_client
assert mock_redis_cls.from_url.call_count == 2
@patch("onyx.background.celery.celery_redis.Redis")
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
mock_redis_cls.from_url.return_value = MagicMock()
app = _make_mock_app("redis://custom-host:6380/3")
celery_redis.celery_get_broker_client(app)
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://custom-host:6380/3"

View File

@@ -0,0 +1,173 @@
"""Unit tests for the Emitter class.
All tests use streaming mode (a merged_queue is required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
"""Return (emitter, queue) wired together."""
mq: queue.Queue = queue.Queue()
return Emitter(merged_queue=mq, model_idx=model_idx), mq
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
def test_emit_lands_on_merged_queue(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet())
assert not mq.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
def test_multiple_packets_delivered_fifo(self) -> None:
emitter, mq = _make_emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
_, t1 = mq.get_nowait()
_, t2 = mq.get_nowait()
assert t1.placement.turn_index == 0
assert t2.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
def test_key_equals_model_idx(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_n1_key_is_zero(self) -> None:
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
def test_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
emitter, mq = _make_emitter()
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
emitter, mq = _make_emitter()
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)
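For orientation, a hedged sketch of the Emitter behaviour these tests exercise: tag the packet's placement with model_idx, then put a (model_idx, packet) tuple on the shared merged queue (reconstructed from the assertions; the real class lives in onyx.chat.emitter and likely does more):

import queue

from onyx.server.query_and_chat.streaming_models import Packet

class EmitterSketch:
    """Minimal stand-in mirroring the observable contract tested above."""

    def __init__(self, merged_queue: queue.Queue, model_idx: int = 0) -> None:
        self._merged_queue = merged_queue
        self._model_idx = model_idx

    def emit(self, packet: Packet) -> None:
        # Tag which model produced the packet, then route it to the shared
        # queue keyed by model index so the drain loop can interleave streams.
        packet.placement.model_index = self._model_idx
        self._merged_queue.put((self._model_idx, packet))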

View File

@@ -0,0 +1,768 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
"""Reset EE global state after each test.
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
(via the celery client import chain). Without this fixture, the EE flag stays
True for the rest of the session and breaks unrelated tests that mock Confluence
or other connectors and assume EE is disabled.
"""
original = global_version._is_ee
yield
global_version._is_ee = original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
return next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
setup = _make_setup(n_models=2)
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 failed")
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_stop_button_calls_completion_for_all_models(self) -> None:
"""llm_loop_completion_handle must be called for all models when the stop button fires.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator. The finally block sets drain_done (signalling
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
server thread is never blocked. Worker threads detect drain_done.is_set()
after run_llm_loop completes and self-persist the result via
llm_loop_completion_handle using their own DB session.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
# so the self-completion path (drain_done.is_set() check) is always taken.
disconnect_received = threading.Event()
# Set by the llm_loop_completion_handle mock when called.
completion_called = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give the drain loop a yield point), then block
until the main thread signals that gen.close() has been called. This
ensures drain_done is set before we return so model_succeeded is checked
against a set drain_done — no race condition.
"""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
disconnect_received.wait(timeout=5)
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
# cast to Generator so .close() is available; _run_models returns
# AnswerStream (= Iterator) but the actual object is always a generator.
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Unblock the worker now that drain_done has been set by gen.close().
disconnect_received.set()
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
# Wait here, inside the patch context, so that get_session_with_current_tenant
# and llm_loop_completion_handle mocks are still active when the worker calls them.
assert completion_called.wait(
timeout=5
), "worker must self-complete via drain_done within 5 seconds"
assert (
mock_handle.call_count == 1
), "completion handle must be called once for the successful model"
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
"""B1 regression: model finishes BEFORE GeneratorExit fires.
The worker exits _run_model with drain_done.is_set()=False and skips
self-completion. When gen.close() fires afterward, the finally else-branch
must detect model_succeeded=True and call llm_loop_completion_handle itself.
Contrast with test_http_disconnect_completion_via_generator_exit, which
tests the opposite ordering (worker finishes AFTER disconnect).
"""
import threading
import time
completion_called = threading.Event()
def emit_and_return_immediately(**kwargs: Any) -> None:
# Emit one packet so the drain loop has something to yield, then return
# immediately — no blocking. The worker will be done in microseconds.
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_and_return_immediately,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
first = next(gen)
assert isinstance(first, Packet)
# Give the worker thread time to finish completely (emit + return +
# finally + self-completion check). It does almost no work, so 100 ms
# is far more than enough while still keeping the test fast.
time.sleep(0.1)
# Now close — worker is already done, so else-branch handles completion.
gen.close()
assert completion_called.wait(
timeout=5
), "disconnect handler must call completion for a model that already finished"
assert mock_handle.call_count == 1, "completion must be called exactly once"
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
"""B2 regression: stop-button must NOT call completion for an errored model.
When model 0 raises an exception, its reserved ChatMessage must not be
saved with 'stopped by user' — that message is wrong for a model that
errored. llm_loop_completion_handle must only be called for non-errored
models when the stop button fires.
"""
setup = _make_setup(n_models=2)
# Return False immediately so the stop-button path fires while model 1
# is still sleeping (model 0 has already errored by then).
setup.check_is_connected = lambda: False
def fail_model_0(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 errored")
# Model 1: run forever (stop button fires before it finishes)
time.sleep(10)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Completion must NOT be called for model 0 (it errored).
# It MAY be called for model 1 (still in-flight when stop fired).
for call in mock_handle.call_args_list:
assert (
call.kwargs.get("llm") is not setup.llms[0]
), "llm_loop_completion_handle must not be called for the errored model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -0,0 +1,147 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
mock = MagicMock(spec=JIRA)
mock._options = {"server": "https://jira.example.com"}
mock._session = MagicMock()
mock._get_url = MagicMock(
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
)
return mock
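# Behaviour assumed of bulk_fetch_issues by the tests below (an inference from
# the assertions, not a restatement of the connector code): POST the full id
# list to the bulkfetch endpoint; on a requests JSONDecodeError, split the
# batch and retry each half recursively; a batch of size one that still fails
# re-raises the decode error; any other exception propagates unchanged.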
def test_bulk_fetch_success() -> None:
"""Happy path: all issues fetched in one request."""
client = _mock_jira_client()
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
result = bulk_fetch_issues(client, ["1", "2", "3"])
assert len(result) == 3
assert all(isinstance(r, Issue) for r in result)
client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
client = _mock_jira_client()
call_count = 0
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
nonlocal call_count
call_count += 1
ids = json["issueIdsOrKeys"]
if len(ids) > 2:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 2294125
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
assert len(result) == 4
returned_ids = {r.raw["id"] for r in result}
assert returned_ids == {"1", "2", "3", "4"}
assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
"""A single issue that always fails JSON decode raises after splitting."""
client = _mock_jira_client()
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if "bad" in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 100
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
"""Non-JSONDecodeError exceptions still propagate."""
client = _mock_jira_client()
resp = MagicMock()
resp.json.side_effect = ValueError("something else broke")
client._session.post.return_value = resp
with pytest.raises(ValueError):
bulk_fetch_issues(client, ["1"])
def test_bulk_fetch_with_fields() -> None:
"""Fields parameter is forwarded correctly."""
client = _mock_jira_client()
raw = [_make_raw_issue("1")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
bulk_fetch_issues(client, ["1"], fields="summary,description")
call_payload = client._session.post.call_args[1]["json"]
assert call_payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
client = _mock_jira_client()
bad_id = "BAD"
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if bad_id in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"truncated", "doc", 999
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that generic_doc_sync derives indexing_start from cc_pair
and forwards it to retrieve_all_slim_docs_perm_sync."""
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that indexing_start is None when the connector has no indexing_start set."""
mock_jira_cc_pair.connector.indexing_start = None
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] is None

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
assert provider.deleted is True
mock_db_session.delete.assert_called_once_with(provider)
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

View File

@@ -0,0 +1,223 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int,
) -> DocMetadataAwareIndexChunk:
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
doc = Document(
id=doc_id,
sections=[TextSection(text="test", link="http://test.com")],
source=DocumentSource.FILE,
semantic_identifier="test_doc",
metadata={},
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links={0: "http://test.com"},
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(full_embedding=[0.1] * 10, mini_chunk_embeddings=[]),
title_embedding=[0.1] * 10,
tenant_id="test_tenant",
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
ancestor_hierarchy_node_ids=[],
)
def _make_index() -> tuple[OpenSearchDocumentIndex, MagicMock]:
"""Creates an OpenSearchDocumentIndex with a mocked client.
Returns the index and the mock for bulk_index_documents."""
mock_client = MagicMock()
mock_bulk = MagicMock()
mock_client.bulk_index_documents = mock_bulk
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
index._index_name = "test_index"
index._client = mock_client
index._tenant_state = tenant_state
return index, mock_bulk
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=0,
new_chunk_cnt=chunk_count,
),
},
)
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_under_batch_limit_flushes_once() -> None:
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 50
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 1
batch_arg = mock_bulk.call_args_list[0]
assert len(batch_arg.kwargs["documents"]) == num_chunks
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
assert mock_bulk.call_count == 3
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 100, 50]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_exactly_at_batch_limit() -> None:
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
(the flush happens on the next chunk, not at the boundary)."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 100
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# The flush check (len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH) only fires
# when the NEXT chunk for the same doc arrives. With exactly 100 chunks there
# is no 101st chunk to trigger it, so the single final flush handles all 100.
assert mock_bulk.call_count == 1
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_one_over_batch_limit() -> None:
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
the 101st is flushed at the end."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 101
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 2
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 1]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
"""Multiple documents each under the batch limit should flush once per document."""
index, mock_bulk = _make_index()
chunks = []
for doc_idx in range(3):
doc_id = f"doc_{doc_idx}"
for chunk_idx in range(50):
chunks.append(_make_chunk(doc_id, chunk_idx))
metadata = IndexingMetadata(
doc_id_to_chunk_cnt_diff={
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
for i in range(3)
},
)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 3 documents = 3 flushes: one at each of the 2 doc boundaries plus the final flush
assert mock_bulk.call_count == 3
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_delete_called_once_per_document() -> None:
"""Even with multiple flushes for a single document, delete should only be
called once per document."""
index, _mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0) as mock_delete:
index.index(chunks, metadata)
mock_delete.assert_called_once_with(doc_id, None)

View File

@@ -0,0 +1,152 @@
"""Unit tests for VespaDocumentIndex.index().
These tests mock all external I/O (HTTP calls, thread pools) and verify
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
construction.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
) -> DocMetadataAwareIndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test_doc",
sections=[TextSection(text=content, link=None)],
source=DocumentSource.NOT_APPLICABLE,
metadata={},
)
index_chunk = IndexChunk(
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=index_chunk,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
tenant_id="test_tenant",
)
def _make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
def _stub_enrich(
doc_id: str,
old_chunk_cnt: int,
) -> EnrichedDocumentIndexingInfo:
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
return EnrichedDocumentIndexingInfo(
doc_id=doc_id,
chunk_start_index=0,
old_version=False,
chunk_end_index=old_chunk_cnt,
)
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
@patch(
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
3,
)
def test_index_respects_batch_size(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
multiple times with correctly sized batches."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
results = index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
assert mock_batch_index.call_count == 3
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
assert batch_sizes == [3, 3, 1]
# Verify all chunks are accounted for and in order
all_indexed = [
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
]
assert len(all_indexed) == 7
assert [c.chunk_id for c in all_indexed] == list(range(7))

View File

@@ -0,0 +1,391 @@
"""Unit tests for _embed_chunks_to_store.
Tests cover:
- Single batch, no failures
- Multiple batches, no failures
- Failure in a single batch
- Cross-batch document failure scrubbing
- Later batches skip already-failed docs
- Empty input
- All chunks fail
"""
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
def _make_doc(doc_id: str) -> Document:
return Document(
id=doc_id,
semantic_identifier="test",
source=DocumentSource.FILE,
sections=[TextSection(text="test", link=None)],
metadata={},
)
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
return DocAwareChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
)
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
"""Create an IndexChunk (a DocAwareChunk with embeddings)."""
return IndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
def _make_failure(doc_id: str) -> ConnectorFailure:
return ConnectorFailure(
failed_document=DocumentFailure(document_id=doc_id, document_link=None),
failure_message="embedding failed",
exception=RuntimeError("embedding failed"),
)
def _mock_embed_success(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Simulate successful embedding of all chunks."""
return (
[_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
[],
)
def _mock_embed_fail_doc(
fail_doc_id: str,
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
"""Return an embed mock that fails all chunks for a specific doc."""
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
successes = [
_make_index_chunk(c.source_document.id, c.chunk_id)
for c in chunks
if c.source_document.id != fail_doc_id
]
failures = (
[_make_failure(fail_doc_id)]
if any(c.source_document.id == fail_doc_id for c in chunks)
else []
)
return successes, failures
return _embed
class TestEmbedChunksInBatches:
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
"""All chunks fit in one batch and embed successfully."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(3)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 3
assert len(result.connector_failures) == 0
# Verify stored contents
assert len(store._batch_files()) == 1
stored = list(store.stream())
assert len(stored) == 3
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
"""Chunks are split across multiple batches, all succeed."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(7)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 7
assert len(result.connector_failures) == 0
assert len(store._batch_files()) == 3 # 3 + 3 + 1
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
"""One doc fails embedding, its chunks are excluded from results."""
mock_embed.side_effect = _mock_embed_fail_doc("doc2")
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc2", 1),
_make_chunk("doc1", 2),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.connector_failures) == 1
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc2" not in successful_doc_ids
assert "doc1" in successful_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_cross_batch_failure_scrubs_earlier_batch(
self, mock_embed: MagicMock
) -> None:
"""Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
batch 1. Its chunks should be scrubbed from batch 0's batch file."""
call_count = 0
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
nonlocal call_count
call_count += 1
if call_count == 1:
return _mock_embed_success(chunks)
else:
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should be fully excluded from results
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "docA" not in successful_doc_ids
assert "docB" in successful_doc_ids
assert len(result.connector_failures) == 1
# Verify batch 0 was scrubbed of docA chunks
all_stored = list(store.stream())
stored_doc_ids = {c.source_document.id for c in all_stored}
assert "docA" not in stored_doc_ids
assert "docB" in stored_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
"""If docA fails in batch 0, its chunks in batch 1 are skipped
entirely (never sent to the embedder)."""
embedded_doc_ids: list[str] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
for c in chunks:
embedded_doc_ids.append(c.source_document.id)
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
_embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should only appear in batch 0, not batch 1
batch_1_doc_ids = embedded_doc_ids[3:]
assert "docA" not in batch_1_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
self, mock_embed: MagicMock
) -> None:
"""doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
should still be embedded successfully."""
embedded_chunks: list[list[str]] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
embedded_chunks.append([c.source_document.id for c in chunks])
return _mock_embed_fail_doc("doc1")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc1", 1),
_make_chunk("doc1", 2),
_make_chunk("doc1", 3),
_make_chunk("doc2", 0),
_make_chunk("doc2", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# doc1 should be fully excluded, doc2 fully included
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc1" not in successful_doc_ids
assert "doc2" in successful_doc_ids
assert len(result.successful_chunk_ids) == 2 # doc2's 2 chunks
# Batch 1 should only contain doc2 (doc1 was filtered before embedding)
assert len(embedded_chunks) == 2
assert "doc1" not in embedded_chunks[1]
assert embedded_chunks[1] == ["doc2", "doc2"]
# Verify on-disk state has no doc1 chunks
all_stored = list(store.stream())
assert all(c.source_document.id == "doc2" for c in all_stored)
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
def test_empty_input(self, mock_embed: MagicMock) -> None:
"""Empty chunk list produces empty results."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=[],
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 0
mock_embed.assert_not_called()
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
"""When all documents fail, results have no successful chunks."""
def _fail_all(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
doc_ids = {c.source_document.id for c in chunks}
return [], [_make_failure(doc_id) for doc_id in doc_ids]
mock_embed.side_effect = _fail_all
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 2

View File

@@ -116,7 +116,7 @@ def _run_adapter_build(
project_ids_map: dict[str, list[int]],
persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
with all external dependencies mocked."""
from onyx.indexing.adapters.user_file_indexing_adapter import (
UserFileIndexingAdapter,
@@ -155,18 +155,16 @@ def _run_adapter_build(
side_effect=Exception("no LLM in tests"),
),
):
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id="test_tenant",
enricher = adapter.prepare_enrichment(
context=context,
tenant_id="test_tenant",
chunks=[chunk],
)
return result.chunks
return [enricher.enrich_chunk(chunk, 1.0)]
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
def test_prepare_enrichment_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
fetched from the DB into each chunk's metadata."""
file_id = str(uuid4())
persona_ids = [5, 12]
@@ -183,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
assert chunks[0].user_project == project_ids
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
"""When a file has no persona or project associations in the DB, the
adapter should default to empty lists (not KeyError or None)."""
file_id = str(uuid4())

View File

@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction
import onyx.llm.models
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLMUserIdentity
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
def test_prompt_contains_tool_call_history_true() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(name="get_weather", arguments="{}"),
)
],
),
]
assert _prompt_contains_tool_call_history(messages) is True
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="Hello"),
AssistantMessage(content="Hi there!"),
]
assert _prompt_contains_tool_call_history(messages) is False
def test_prompt_contains_tool_call_history_false_user_only() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [UserMessage(content="Hello")]
assert _prompt_contains_tool_call_history(messages) is False
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
"""When thinking is enabled but assistant messages with tool_calls lack
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
BadRequestError about missing thinking blocks."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
onyx.llm.models.ToolMessage(
content="22°C sunny",
tool_call_id="tc_1",
),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" not in kwargs, (
"thinking param should be dropped when thinking_blocks are missing "
"from assistant messages with tool_calls"
)
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
"""When thinking is enabled and there are no historical assistant messages
with tool_calls, the thinking param should be preserved."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" in kwargs, (
"thinking param should be preserved when no assistant messages "
"with tool_calls exist in history"
)
assert kwargs["thinking"]["type"] == "enabled"
def test_bifrost_claude_includes_allowed_openai_params() -> None:
llm = LitellmLLM(
api_key="test_key",

View File

@@ -3,7 +3,7 @@
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
# ---------------------------------------------------------------------------
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(

View File

@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.natural_language_processing import utils as nlp_utils
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.server.features.projects import projects_file_utils as utils
from onyx.server.settings.models import Settings
class _Tokenizer:
class _Tokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
return [1] * len(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
return Settings(
user_file_max_upload_size_mb=upload_size_mb,
file_token_count_threshold_k=token_threshold_k,
)
def _patch_common_dependencies(
monkeypatch: pytest.MonkeyPatch,
upload_size_mb: int = 1,
token_threshold_k: int = 100,
) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
monkeypatch.setattr(
utils,
"load_settings",
lambda: _make_settings(upload_size_mb, token_threshold_k),
)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("edge.png", size=100)
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
upload = _make_upload("edge.png", size=1048576)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=101)
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=101)
large = _make_upload("large.png", size=1048577)
result = utils.categorize_uploaded_files([small, large], MagicMock())
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
def test_categorize_uploaded_files_enforces_size_limit_always(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
upload = _make_upload("oversized.pdf", size=101)
upload = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=101)
oversized_doc = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_accepts_python_file(
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
"""A positive upload_size_mb is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
py_source = b'def hello():\n print("world")\n'
monkeypatch.setattr(
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
)
upload = _make_upload("script.py", size=len(py_source), content=py_source)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "script.py"
assert len(result.rejected) == 0
def test_categorize_uploaded_files_rejects_binary_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
binary_content = bytes(range(256)) * 4
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
upload = _make_upload("huge.png", size=1048577, content=b"x")
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].filename == "data.bin"
assert "Unsupported file type" in result.rejected[0].reason
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive token_threshold_k is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("big_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_no_token_limit_when_threshold_k_is_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
monkeypatch.setattr(
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
)
upload = _make_upload("huge_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 0
assert len(result.acceptable) == 1
def test_categorize_both_limits_enforced(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Both positive limits are enforced; file exceeding token limit is rejected."""
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("over_tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 5K token limit"
def test_categorize_rejection_reason_contains_dynamic_values(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
# File within size limit but over token limit
upload = _make_upload("tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert result.rejected[0].reason == "Exceeds 7K token limit"
# File over size limit
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
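Taken together, the size- and token-limit tests above fix both the thresholds and the exact rejection strings; a compact sketch of that check, assuming both limits come straight from the admin-configured Settings (the function name and signature are illustrative, not the real categorize_uploaded_files internals):
def check_limits_sketch(
    size_bytes: int,
    estimated_tokens: int,
    upload_size_mb: int,
    token_threshold_k: int,
) -> str | None:
    """Return a rejection reason, or None when the file is acceptable."""
    # Size limit: a positive upload_size_mb is always enforced (1 MB = 1048576 bytes).
    if size_bytes > upload_size_mb * 1024 * 1024:
        return f"Exceeds {upload_size_mb} MB file size limit"
    # Token limit: 0 means "no limit"; positive values are in thousands of tokens.
    if token_threshold_k > 0 and estimated_tokens > token_threshold_k * 1000:
        return f"Exceeds {token_threshold_k}K token limit"
    return None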
# --- count_tokens tests ---
def test_count_tokens_small_text() -> None:
"""Small text should be encoded in a single call and return correct count."""
tokenizer = _Tokenizer()
text = "hello world"
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_chunked_matches_single_call() -> None:
"""Chunked encoding should produce the same result as single-call for small text."""
tokenizer = _Tokenizer()
text = "a" * 1000
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
# _Tokenizer returns 1 token per char, so total should be 250
assert count_tokens(text, tokenizer) == 250
def test_count_tokens_with_token_limit_exits_early(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set and exceeded, count_tokens should stop early."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
tokenizer = _CountingTokenizer()
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
text = "a" * 500
result = count_tokens(text, tokenizer, token_limit=150)
assert result == 200 # 2 chunks × 100 tokens each
assert encode_call_count == 2, "Should have stopped after 2 chunks"
def test_count_tokens_with_token_limit_not_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set but not exceeded, all chunks are encoded."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
result = count_tokens(text, tokenizer, token_limit=1000)
assert result == 250
def test_count_tokens_no_limit_encodes_all_chunks(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Without token_limit, all chunks are encoded regardless of count."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 500
result = count_tokens(text, tokenizer)
assert result == 500
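The count_tokens behaviour these tests encode — encode in character chunks, keep a running total, and bail out once an optional token_limit is crossed — can be sketched as follows; the real implementation lives in onyx.natural_language_processing.utils and may differ in detail, and the default chunk size here is an assumption (the tests patch it to 100).
from typing import Protocol

_ENCODE_CHUNK_SIZE = 8192  # assumed default; the tests above patch this to 100


class _SupportsEncode(Protocol):
    def encode(self, text: str) -> list[int]: ...


def count_tokens_sketch(
    text: str,
    tokenizer: _SupportsEncode,
    token_limit: int | None = None,
) -> int:
    total = 0
    for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
        total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
        if token_limit is not None and total > token_limit:
            break  # early exit: the caller only needs to know the limit was crossed
    return total
With the 1-token-per-char _Tokenizer, 500 chars, chunk size 100, and token_limit=150, this returns 200 after two encode calls — exactly what test_count_tokens_with_token_limit_exits_early asserts.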
# --- early exit via token_limit in categorize tests ---
def test_categorize_early_exits_tokenization_for_large_text(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Large text files should be rejected via early-exit tokenization
without encoding all chunks."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
# token_threshold_k=1 → 1000-token threshold. _Tokenizer yields 1 token per char,
# so with _ENCODE_CHUNK_SIZE patched to 100 the 5000-char text below crosses the
# threshold after the 11th chunk — well before all 50 chunks are encoded.
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
large_text = "x" * 5000 # 5000 tokens, threshold 1000
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 1
assert "token limit" in result.rejected[0].reason
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
assert (
encode_call_count < 50
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
def test_categorize_text_under_token_limit_accepted(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Text files under the token threshold should be accepted with exact count."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
small_text = "x" * 500 # 500 tokens < 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable_file_to_token_count["ok.txt"] == 500

View File

@@ -1,12 +1,23 @@
import pytest
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
class _FakeKvStore:
def __init__(self, data: dict | None = None) -> None:
self._data = data
def load(self, _key: str) -> dict:
raise KvKeyNotFoundError()
if self._data is None:
raise KvKeyNotFoundError()
return self._data
class _FakeCache:
@@ -20,13 +31,140 @@ class _FakeCache:
self._vals[key] = value.encode("utf-8")
def test_load_settings_includes_user_file_max_upload_size_mb(
def test_load_settings_uses_model_defaults_when_no_stored_value(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When no settings are stored (vector DB enabled), load_settings() should
resolve the default token threshold to 200."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 77
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled and no settings are stored, the token
threshold should default to 10000 (10M tokens)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
)
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled but admin explicitly set a token threshold,
that value should be preserved (not overridden by the 10000 default)."""
stored = Settings(file_token_count_threshold_k=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 500
def test_load_settings_preserves_zero_token_threshold(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 means 'no limit' and should be preserved."""
stored = Settings(file_token_count_threshold_k=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 0
def test_load_settings_resolves_zero_upload_size_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
def test_load_settings_clamps_upload_size_to_env_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be clamped to the env-configured maximum."""
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 250
def test_load_settings_preserves_upload_size_within_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be preserved unchanged."""
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 150
def test_load_settings_zero_upload_size_resolves_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default,
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 100
def test_load_settings_default_clamped_to_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
the effective default should be min(DEFAULT, MAX)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 50
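Read as a group, these tests pin down the resolution rules load_settings is expected to apply; a condensed sketch under the assumption that the logic reduces to the following (names and structure are illustrative, not the real onyx.server.settings.store code):
def _resolve_limits_sketch(
    stored_upload_mb: int | None,
    stored_token_threshold_k: int | None,
    *,
    default_upload_mb: int,
    max_allowed_upload_mb: int,
    disable_vector_db: bool,
) -> tuple[int, int]:
    # Upload size: 0 or unset falls back to the default, then clamp to the env max.
    upload_mb = min(stored_upload_mb or default_upload_mb, max_allowed_upload_mb)
    # Token threshold: preserve an explicit value (including 0 = "no limit");
    # otherwise pick the default for the vector-DB mode (200 with vector DB,
    # 10000 without, per the constants imported at the top of this file).
    if stored_token_threshold_k is None:
        token_k = 10_000 if disable_vector_db else 200
    else:
        token_k = stored_token_threshold_k
    return upload_mb, token_k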

View File

@@ -1,5 +1,6 @@
"""Tests for indexing pipeline Prometheus collectors."""
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
@pytest.fixture(autouse=True)
def _mock_broker_client() -> Iterator[None]:
"""Patch celery_get_broker_client for all collector tests."""
with patch(
"onyx.background.celery.celery_redis.celery_get_broker_client",
return_value=MagicMock(),
):
yield
class TestQueueDepthCollector:
def test_returns_empty_when_factory_not_set(self) -> None:
collector = QueueDepthCollector()
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
def test_collects_queue_depths(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
collector.set_celery_app(MagicMock())
with (
patch(
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
def test_handles_redis_error_gracefully(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
with patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
def test_caching_returns_stale_within_ttl(self) -> None:
collector = QueueDepthCollector(cache_ttl=60)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
with (
patch(
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
assert first is second # Same object, from cache
def test_factory_called_each_scrape(self) -> None:
"""Verify the Redis factory is called on each fresh collect, not cached."""
collector = QueueDepthCollector(cache_ttl=0)
factory = MagicMock(return_value=MagicMock())
collector.set_redis_factory(factory)
with (
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
return_value=0,
),
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
return_value=set(),
),
):
collector.collect()
collector.collect()
assert factory.call_count == 2
def test_error_returns_stale_cache(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
# First call succeeds
with (

View File

@@ -1,96 +1,22 @@
"""Tests for indexing pipeline setup (Redis factory caching)."""
"""Tests for indexing pipeline setup."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
class TestCollectorCeleryAppSetup:
def test_queue_depth_collector_uses_celery_app(self) -> None:
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
collector = QueueDepthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
def test_caches_redis_client_across_calls(self) -> None:
"""Factory should reuse the same client on subsequent calls."""
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client1 = factory()
client2 = factory()
assert client1 is client2
# broker_connection should only be called once
assert mock_app.broker_connection.call_count == 1
def test_reconnects_when_ping_fails(self) -> None:
"""Factory should create a new client if ping fails (stale connection)."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
client1 = factory()
assert client1 is mock_client_stale
assert mock_app.broker_connection.call_count == 1
# Switch to fresh client for next connection
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails on stale, reconnects
client2 = factory()
assert client2 is mock_client_fresh
assert mock_app.broker_connection.call_count == 2
def test_reconnect_closes_stale_client(self) -> None:
"""When ping fails, the old client should be closed before reconnecting."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
factory()
# Switch to fresh client
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails, should close stale client
factory()
mock_client_stale.close.assert_called_once()
def test_first_call_creates_connection(self) -> None:
"""First call should always create a new connection."""
mock_client = MagicMock()
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client = factory()
assert client is mock_client
mock_app.broker_connection.assert_called_once()
def test_redis_health_collector_uses_celery_app(self) -> None:
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
collector = RedisHealthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
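The tests here only assert that set_celery_app stores the app; how a collector turns it into a broker Redis client at scrape time is not shown. The autouse fixture earlier patches onyx.background.celery.celery_redis.celery_get_broker_client, so the real code presumably goes through that helper; the sketch below instead inlines the attribute chain the removed _make_broker_redis_factory used, and everything about it is an assumption.
from typing import Any


class _CeleryBackedCollectorSketch:
    def __init__(self) -> None:
        self._celery_app: Any | None = None

    def set_celery_app(self, app: Any) -> None:
        self._celery_app = app

    def _broker_redis_client(self) -> Any:
        # On-demand resolution, no caching: same chain the removed factory used.
        assert self._celery_app is not None
        return self._celery_app.broker_connection().channel().client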

View File

@@ -1,6 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
import queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
def emitter_queue() -> queue.Queue:
return queue.Queue()
@pytest.fixture
def emitter(emitter_queue: queue.Queue) -> Emitter:
return Emitter(merged_queue=emitter_queue)
@pytest.fixture
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
def test_emit_start_emits_memory_tool_start_packet(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
) -> None:
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolStart)
assert packet.placement == placement
assert packet.placement is not None
assert packet.placement.turn_index == placement.turn_index
assert packet.placement.tab_index == placement.tab_index
assert packet.placement.model_index == 0 # emitter stamps model_index=0
def test_emit_start_with_different_placement(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
) -> None:
placement = Placement(turn_index=2, tab_index=1)
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert packet.placement.turn_index == 2
assert packet.placement.tab_index == 1
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
memory="User prefers Python",
)
# The delta packet should be in the queue
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers Python"
assert packet.obj.operation == "add"
assert packet.obj.memory_id is None
assert packet.obj.index is None
assert packet.placement == placement
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
def test_run_emits_delta_for_update_operation(
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
memory="User prefers light mode",
)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers light mode"
assert packet.obj.operation == "update"

Some files were not shown because too many files have changed in this diff.