Compare commits

...

58 Commits

Author SHA1 Message Date
Jean Caillé
569fef22b4 style: align test conventions with existing codebase
- Prefix unused mock params with _ and type as Any
- Extract shared db session mock setup into helper
- Remove unused pytest import
- Use module path constant to reduce repetition
2026-03-31 18:21:42 +02:00
Jean Caille
517d1035e0 Merge branch 'main' into fix/ttl-skip-missing-file-records 2026-03-31 18:13:12 +02:00
Jean Caillé
ad108dd573 test: add unit tests for TTL deletion resilience
- test_chat_deletion: verifies delete_messages_and_files_from_chat_session
  continues when a file record is missing, and handles messages with no files
- test_ttl_task: verifies perform_ttl_management_task continues deleting
  remaining sessions after one fails, and reports correct success status

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 18:11:45 +02:00
Ben Wu
adf5691b5f feat(canvas 2/4): Canvas Connector data fetching (#9386) 2026-03-31 03:07:05 +00:00
Nikolas Garza
c1a8a5bd83 fix(tenants): run migrations on pool tenants before assigning to new users (#9788) 2026-03-31 01:24:01 +00:00
Justin Tahara
8fd486da99 feat(rds): Add Freeable Memory alert (#9787) 2026-03-31 00:59:30 +00:00
Raunak Bhagat
4bda4d3637 refactor: migrate away from cards/Select (#9771) 2026-03-31 00:27:01 +00:00
Justin Tahara
13c25eadad feat(rds): Adding CPU Alerts (#9784) 2026-03-31 00:22:15 +00:00
Justin Tahara
1f244e6388 feat(eks): Adding Cloudwatch logging (#9783) 2026-03-30 23:52:44 +00:00
Nikolas Garza
18b0416d30 feat(sentry): enable frontend source map uploads in cloud CI (#9775) 2026-03-30 23:42:57 +00:00
Nikolas Garza
4bc0bc1efb feat(helm): add Grafana dashboard provisioning (#9725) 2026-03-30 23:42:32 +00:00
Justin Tahara
1555217061 feat(rds): Adding RDS Snapshosts (#9779) 2026-03-30 23:17:08 +00:00
Nikolas Garza
d177a833f0 feat(sentry): add release tracking to backend and frontend (#9773) 2026-03-30 22:35:38 +00:00
Jamison Lahman
086997d3c5 chore(types): fix IconButton size props (#9772) 2026-03-30 21:40:25 +00:00
dependabot[bot]
dccec78397 chore(deps): bump helm/chart-testing-action from b5eebdd9998021f29756c53432f48dab66394810 to 2e2940618cb426dce2999631d543b53cdcfc8527 (#9764)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 14:41:01 -07:00
Jamison Lahman
0123133621 chore(fe): polish Query History table (#9767) 2026-03-30 21:30:13 +00:00
dependabot[bot]
0b9d154a73 chore(deps): bump runs-on/cache from 50350ad4242587b6c8c2baa2e740b1bc11285ff4 to a5f51d6f3fece787d03b7b4e981c82538a0654ed (#9763)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:54:43 -07:00
dependabot[bot]
6e65e55bf5 chore(deps): bump actions/cache from 5.0.3 to 5.0.4 (#9765)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:46:53 -07:00
Raunak Bhagat
3f9e208759 feat(opal): SelectCard + CardHeaderLayout (#9760) 2026-03-30 19:54:54 +00:00
dependabot[bot]
fb8edda14a chore(deps): bump pygments from 2.19.2 to 2.20.0 (#9757)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-30 18:30:18 +00:00
Jean Caillé
2a5810a44a fix: TTL management task no longer crashes on missing file records
The TTL cleanup task would crash with a RuntimeError when trying to
delete a file record that no longer exists, blocking cleanup of all
remaining sessions. Two fixes:

1. Wrap file deletion in delete_messages_and_files_from_chat_session
   with try/except — a missing file is already the desired state, so
   log a warning and continue.

2. Add per-session error handling in the TTL task loop. The existing
   comment said "one session per delete so that we don't blow up" but
   the outer try/except still aborted on the first failure. Now each
   session deletion is individually wrapped so failures don't block
   cleanup of subsequent sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-30 19:25:36 +02:00
Jamison Lahman
58decd8a6b chore(gha): prefer ci-protected env (#9728) 2026-03-30 17:20:54 +00:00
Danelegend
e97204d9cc feat(indexing): Batch chunks during doc processing (#9468) 2026-03-30 11:49:36 +00:00
Danelegend
44ab02c94f refactor(indexing): Refactor indexing vector db abstraction (#9653) 2026-03-30 09:57:16 +00:00
Danelegend
a98cc30f25 refactor(indexing): Change adapters to support iterables (#9469) 2026-03-30 01:43:10 +00:00
Danelegend
a709dcb8fa feat(indexing): Max chunk processing (#9400) 2026-03-30 00:10:24 +00:00
Raunak Bhagat
a3dfe6aa1b refactor(opal): unify Interactive color system (#9717) 2026-03-28 00:40:23 +00:00
Nikolas Garza
23e4d55fb1 perf(swr): convert raw-fetch hooks to SWR to eliminate duplicate requests (#9694) 2026-03-28 00:26:20 +00:00
Jamison Lahman
470cc85f83 feat(cli): onyx-cli serve over SSH (#9726) 2026-03-27 23:46:14 +00:00
Justin Tahara
64d9be5a41 fix(openpyxl): Colors must be aRGB hex values (#9727) 2026-03-27 23:14:36 +00:00
roshan
71a5b469b0 feat(widget): add citation badges to chat widget (#9714) 2026-03-27 22:39:46 +00:00
Evan Lohn
462eb0697f fix: Anthropic litellm thinking workaround (#9713) 2026-03-27 21:03:05 +00:00
dependabot[bot]
b708dc8796 chore(deps): bump langchain-core from 1.2.11 to 1.2.22 (#9720)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:50:19 +00:00
dependabot[bot]
c9e2c32f55 chore(deps): bump cryptography from 46.0.5 to 46.0.6 (#9721)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:48:59 +00:00
Jamison Lahman
d725df62e7 feat(cli): --version and validate-config warn if backend version is incompatible (#9715) 2026-03-27 13:13:16 -07:00
Jamison Lahman
d1460972b6 fix(cli): onyx-cli --version interpolation (#9712) 2026-03-27 19:22:31 +00:00
Jamison Lahman
706872f0b7 chore(deps): upgrade go deps (#9711) 2026-03-27 12:24:25 -07:00
Jamison Lahman
ed3856be2b chore(release): build all CLI wheels before publishing (#9710) 2026-03-27 19:04:02 +00:00
Jamison Lahman
6326c7f0b9 chore(gha): fix git error after helm release migration to alpine base image (#9709) 2026-03-27 11:21:34 -07:00
Jamison Lahman
40420fc4e6 chore(gha): helm release upstream nits (#9708) 2026-03-27 11:10:41 -07:00
Nikolas Garza
1a2b6a66cc fix(celery): use broker connection pool to prevent Redis connection leak (#9682) 2026-03-27 17:53:49 +00:00
Jamison Lahman
d1b1529ccf chore(gha): fix helm release after image update (#9707) 2026-03-27 10:37:43 -07:00
Bo-Onyx
fedd9c76e5 feat(hook): admin page create or edit hook (#9690) 2026-03-27 17:10:14 +00:00
Jamison Lahman
0b34b40b79 chore(gha): pin helm release docker image (#9706) 2026-03-27 10:16:41 -07:00
Yuhong Sun
fe82ddb1b9 Update README.md (#9703) 2026-03-27 10:03:56 -07:00
Jamison Lahman
32d3d70525 chore(playwright): deflake settings_pages.spec.ts (#9684) 2026-03-27 15:54:23 +00:00
Jamison Lahman
40b9e10890 chore(devtools): upgrade ods: 0.7.1->0.7.2 (#9701) 2026-03-27 08:17:42 -07:00
dependabot[bot]
e21b204b8a chore(deps): bump brace-expansion in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9698)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-27 08:10:15 -07:00
Jamison Lahman
2f672b3a4f fix(fe): Popover content doesnt overflow on small screens (#9612) 2026-03-27 08:07:52 -07:00
Nikolas Garza
cf19d0df4f feat(helm): add Prometheus metrics ports and Services for celery workers (#9630) 2026-03-27 08:03:48 +00:00
Danelegend
86a6a4c134 refactor(indexing): Vespa & Opensearch index function use Iterable (#9384) 2026-03-27 04:36:59 +00:00
SubashMohan
146b5449d2 feat: configurable file upload size and token limits via admin settings (#9232) 2026-03-27 04:23:16 +00:00
Jamison Lahman
b66991b5c5 chore(devtools): ods trace (#9688)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-27 03:56:38 +00:00
dependabot[bot]
9cb76dc027 chore(deps-dev): bump picomatch from 2.3.1 to 2.3.2 in /web (#9691)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 02:22:22 +00:00
dependabot[bot]
f66891d19e chore(deps-dev): bump handlebars from 4.7.8 to 4.7.9 in /web (#9689)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 01:41:29 +00:00
Nikolas Garza
c07c952ad5 chore(greptile): add nginx routing rule for non-api backend routes (#9687) 2026-03-27 00:34:15 +00:00
Nikolas Garza
be7f40a28a fix(nginx): route /scim/* to api_server (#9686) 2026-03-26 17:21:57 -07:00
Evan Lohn
26f941b5da perf: perm sync start time (#9685) 2026-03-27 00:07:53 +00:00
218 changed files with 11385 additions and 2070 deletions

View File

@@ -704,6 +704,9 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -786,6 +789,9 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -47,7 +47,8 @@ jobs:
done
- name: Publish Helm charts to gh-pages
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts

View File

@@ -35,6 +35,7 @@ jobs:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 5
steps:
- name: Checkout

View File

@@ -183,6 +183,7 @@ jobs:
- cherry-pick-to-latest-release
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout
@@ -232,6 +233,7 @@ jobs:
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout

View File

@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
with:
uv_version: "0.9.9"

View File

@@ -284,7 +284,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
@@ -626,7 +626,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}

View File

@@ -31,6 +31,7 @@ jobs:
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
environment: ci-protected
timeout-minutes: 45
env:

View File

@@ -15,6 +15,7 @@ permissions:
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd

View File

@@ -13,15 +13,6 @@ jobs:
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
@@ -31,9 +22,11 @@ jobs:
enable-cache: false
version: "0.9.9"
- run: |
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
for goos in linux windows darwin; do
for goarch in amd64 arm64; do
GOOS="$goos" GOARCH="$goarch" uv build --wheel
done
done
working-directory: cli
- run: uv publish
working-directory: cli

View File

@@ -25,6 +25,7 @@ permissions:
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
@@ -54,6 +55,7 @@ jobs:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

View File

@@ -9,6 +9,7 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 45
permissions:
contents: read

View File

@@ -11,6 +11,7 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 45
steps:

View File

@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -122,7 +122,7 @@ repos:
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
hooks:
- id: golangci-lint
language_version: "1.26.0"
language_version: "1.26.1"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit

View File

@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> ```
****

View File

@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_permission_sync_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
@@ -473,6 +474,8 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)

View File

@@ -13,6 +13,7 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -29,9 +30,10 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
@shared_task(
@@ -91,8 +93,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -103,12 +104,14 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -121,6 +124,46 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
"""
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
idempotent, tenants already at head are a fast no-op. This ensures pool
tenants are always current so that signup doesn't hit schema mismatches
(e.g. missing columns added after the tenant was pre-provisioned).
"""
with get_session_with_shared_schema() as db_session:
pool_tenants = db_session.query(AvailableTenant).all()
tenant_ids = [t.tenant_id for t in pool_tenants]
if not tenant_ids:
return
task_logger.info(
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
)
for tenant_id in tenant_ids:
try:
run_alembic_migrations(tenant_id)
new_version = get_current_alembic_version(tenant_id)
with get_session_with_shared_schema() as db_session:
tenant = (
db_session.query(AvailableTenant)
.filter_by(tenant_id=tenant_id)
.first()
)
if tenant and tenant.alembic_version != new_version:
task_logger.info(
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
)
tenant.alembic_version = new_version
db_session.commit()
except Exception:
task_logger.exception(
f"Failed to migrate pool tenant {tenant_id}, skipping"
)
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -54,27 +54,35 @@ def perform_ttl_management_task(
retention_limit_days, db_session
)
failures = 0
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
try:
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
failures += 1
logger.exception(
"Failed to delete chat session "
f"user_id={user_id} session_id={session_id}, "
"continuing with remaining sessions"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
success=failures == 0,
)
except Exception:
logger.exception(
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
f"TTL management task failed. user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -105,9 +106,11 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
callback=callback,
start=indexing_start,
)
for doc_metadata_batch in slim_doc_generator:
@@ -180,9 +183,15 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector,
slack_connector=slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -40,10 +41,19 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -99,6 +99,26 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -100,6 +100,7 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -20,6 +20,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -65,6 +66,7 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -515,7 +517,8 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, **kwargs: Any # noqa: ARG001
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -1,5 +1,6 @@
# These are helper objects for tracking the keys we need to write in redis
import json
import threading
from typing import Any
from typing import cast
@@ -7,7 +8,59 @@ from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
_broker_client: Redis | None = None
_broker_url: str | None = None
_broker_client_lock = threading.Lock()
def celery_get_broker_client(app: Celery) -> Redis:
"""Return a shared Redis client connected to the Celery broker DB.
Uses a module-level singleton so all tasks on a worker share one
connection instead of creating a new one per call. The client
connects directly to the broker Redis DB (parsed from the broker URL).
Thread-safe via lock — safe for use in Celery thread-pool workers.
Usage:
r_celery = celery_get_broker_client(self.app)
length = celery_get_queue_length(queue, r_celery)
"""
global _broker_client, _broker_url
with _broker_client_lock:
url = app.conf.broker_url
if _broker_client is not None and _broker_url == url:
try:
_broker_client.ping()
return _broker_client
except Exception:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
elif _broker_client is not None:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
_broker_url = url
_broker_client = Redis.from_url(
url,
decode_responses=False,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
retry_on_timeout=True,
)
return _broker_client
def celery_get_unacked_length(r: Redis) -> int:

View File

@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
r_celery = celery_get_broker_client(self.app)
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -9,6 +9,7 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -137,6 +138,7 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -449,7 +450,7 @@ def check_indexing_completion(
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = task.app.broker_connection().channel().client # type: ignore
redis_celery = celery_get_broker_client(task.app)
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,

View File

@@ -1,6 +1,5 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
@@ -19,6 +18,7 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect queue metrics with broker connection
r_celery = celery_get_broker_client(self.app)
queue_metrics = _collect_queue_metrics(r_celery)
# Collect and log each metric
# Collect remaining metrics (no broker connection needed)
with get_session_with_current_tenant() as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
all_metrics: list[Metric] = queue_metrics
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
for metric in all_metrics:
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(task.app)
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
num_deleted = 0
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")

View File

@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
redis_celery = celery_get_broker_client(celery_app)
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
r_celery = celery_get_broker_client(self.app)
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(

View File

@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -791,6 +805,10 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -212,6 +212,7 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -672,6 +673,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -0,0 +1,32 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
canvas_client: CanvasApiClient,
course_id: int,
) -> ExternalAccess | None:
if not global_version.is_ee_version():
return None
ee_get_course_permissions = cast(
Callable[[CanvasApiClient, int], ExternalAccess | None],
fetch_versioned_implementation(
"onyx.external_permissions.canvas.access",
"get_course_permissions",
),
)
return ee_get_course_permissions(canvas_client, course_id)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
@@ -190,3 +191,22 @@ class CanvasApiClient:
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,17 +1,82 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
"""Map Canvas API errors to connector framework exceptions."""
if e.status_code == 401:
raise CredentialExpiredError(
"Canvas API token is invalid or expired (HTTP 401)."
)
elif e.status_code == 403:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
)
else:
raise ConnectorValidationError(
f"Canvas API error (status={e.status_code}): {e}"
)
class CanvasCourse(BaseModel):
id: int
name: str
course_code: str
created_at: str
workflow_state: str
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
class CanvasPage(BaseModel):
@@ -19,10 +84,22 @@ class CanvasPage(BaseModel):
url: str
title: str
body: str | None = None
created_at: str
updated_at: str
created_at: str | None = None
updated_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
@@ -30,10 +107,23 @@ class CanvasAssignment(BaseModel):
description: str | None = None
html_url: str
course_id: int
created_at: str
updated_at: str
created_at: str | None = None
updated_at: str | None = None
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
@@ -43,6 +133,17 @@ class CanvasAnnouncement(BaseModel):
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
@@ -72,3 +173,286 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
try:
self.canvas_client.get("courses", params={"per_page": "1"})
logger.info("Canvas connector settings validated successfully")
except OnyxError as e:
_handle_canvas_api_error(e)
except ConnectorMissingCredentialError:
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
# TODO(benwu408): implemented in PR4 (perm sync)
raise NotImplementedError

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,7 +950,9 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(_get_page_id(page))
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -72,6 +72,10 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -1765,7 +1765,11 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1786,7 +1790,9 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor
site_descriptor=site_descriptor,
start=start,
end=end,
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1841,7 +1847,9 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(site_descriptor)
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2565,12 +2573,22 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
yield from self._fetch_slim_documents_from_sharepoint()
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
if __name__ == "__main__":

View File

@@ -516,6 +516,8 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -847,8 +851,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -861,6 +865,8 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -185,7 +185,15 @@ def delete_messages_and_files_from_chat_session(
for _, files in messages_with_files:
file_store = get_default_file_store()
for file_info in files or []:
file_store.delete_file(file_id=file_info.get("id"))
try:
file_store.delete_file(file_id=file_info.get("id"))
except Exception:
logger.warning(
"Failed to delete file %s from file store during "
"chat session cleanup, skipping.",
file_info.get("id"),
)
continue
# Delete ChatMessage records - CASCADE constraints will automatically handle:
# - ChatMessage__StandardAnswer relationship records

View File

@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,11 +1,12 @@
import json
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -351,7 +352,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -647,10 +648,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -673,29 +674,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
]
onyx_document: Document = chunks[0].source_document
onyx_document: Document = doc_chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -704,22 +710,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
document_indexing_results.append(document_insertion_record)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
return document_indexing_results

View File

@@ -6,6 +6,7 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,6 +1,8 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -8,6 +10,7 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
existing_docs: set[str] = set()
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -44,6 +44,7 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]

View File

@@ -19,7 +19,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -85,14 +86,21 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
tenant_id: str,
chunks: list[DocAwareChunk],
) -> "DocumentChunkEnricher":
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
no_access = DocumentAccess.build(
user_emails=[],
@@ -102,67 +110,30 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
def _get_ancestor_ids_for_documents(
@@ -203,7 +174,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -227,7 +198,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -249,3 +220,52 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -24,11 +27,13 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -101,13 +106,20 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
no_access = DocumentAccess.build(
user_emails=[],
@@ -117,7 +129,6 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -138,17 +149,6 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -163,46 +163,30 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -246,8 +230,9 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -263,8 +248,10 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
str(user_file.id)
]
@@ -276,8 +263,54 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -0,0 +1,89 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from onyx.indexing.models import IndexChunk
class ChunkBatchStore:
"""Manages serialization of embedded chunks to a temporary directory.
Owns the temp directory lifetime and provides save/load/stream/scrub
operations.
Use as a context manager to ensure cleanup::
with ChunkBatchStore() as store:
store.save(chunks, batch_idx=0)
for chunk in store.stream():
...
"""
_EXT = ".pkl"
def __init__(self) -> None:
self._tmpdir: Path | None = None
# -- context manager -----------------------------------------------------
def __enter__(self) -> "ChunkBatchStore":
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
return self
def __exit__(self, *_exc: object) -> None:
if self._tmpdir is not None:
shutil.rmtree(self._tmpdir, ignore_errors=True)
self._tmpdir = None
@property
def _dir(self) -> Path:
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
return self._tmpdir
# -- storage primitives --------------------------------------------------
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
"""Serialize a batch of embedded chunks to disk."""
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
pickle.dump(chunks, f)
def _load(self, batch_file: Path) -> list[IndexChunk]:
"""Deserialize a batch of embedded chunks from a file."""
with open(batch_file, "rb") as f:
return pickle.load(f)
def _batch_files(self) -> list[Path]:
"""Return batch files sorted by numeric index."""
return sorted(
self._dir.glob(f"batch_*{self._EXT}"),
key=lambda p: int(p.stem.removeprefix("batch_")),
)
# -- higher-level operations ---------------------------------------------
def stream(self) -> Iterator[IndexChunk]:
"""Yield all chunks across all batch files.
Each call returns a fresh generator, so the data can be iterated
multiple times (e.g. once per document index).
"""
for batch_file in self._batch_files():
yield from self._load(batch_file)
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
When a document fails embedding in batch N, earlier batches may
already contain successfully embedded chunks for that document.
This ensures the output is all-or-nothing per document.
"""
for batch_file in self._batch_files():
batch_chunks = self._load(batch_file)
cleaned = [
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
]
if len(cleaned) != len(batch_chunks):
with open(batch_file, "wb") as f:
pickle.dump(cleaned, f)

View File

@@ -1,5 +1,8 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from pydantic import BaseModel
@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class ChunkEmbeddingResult(BaseModel):
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
connector_failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
)
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
"""Extract document IDs from a list of connector failures."""
return {f.failed_document.document_id for f in failures if f.failed_document}
def _embed_chunks_to_store(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
"""Embed chunks in batches, spilling each batch to *store*.
If a document fails embedding in any batch, its chunks are excluded from
all batches (including earlier ones already written) so that the output
is all-or-nothing per document.
"""
successful_chunk_ids: list[tuple[int, str]] = []
all_embedding_failures: list[ConnectorFailure] = []
# Track failed doc IDs across all batches so that a failure in batch N
# causes chunks for that doc to be skipped in batch N+1 and stripped
# from earlier batches.
all_failed_doc_ids: set[str] = set()
for batch_idx, chunk_batch in enumerate(
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
):
# Skip chunks belonging to documents that failed in earlier batches.
chunk_batch = [
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
]
if not chunk_batch:
continue
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
chunks=chunk_batch,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
all_embedding_failures.extend(embedding_failures)
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
# Only keep successfully embedded chunks for non-failed docs.
chunks_with_embeddings = [
c
for c in chunks_with_embeddings
if c.source_document.id not in all_failed_doc_ids
]
successful_chunk_ids.extend(
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
)
store.save(chunks_with_embeddings, batch_idx)
del chunks_with_embeddings
# Scrub earlier batches for docs that failed in later batches.
if all_failed_doc_ids:
store.scrub_failed_docs(all_failed_doc_ids)
successful_chunk_ids = [
(chunk_id, doc_id)
for chunk_id, doc_id in successful_chunk_ids
if doc_id not in all_failed_doc_ids
]
return ChunkEmbeddingResult(
successful_chunk_ids=successful_chunk_ids,
connector_failures=all_embedding_failures,
)
@contextmanager
def embed_and_stream(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
"""Embed chunks to disk and yield a ``(result, store)`` pair.
The store owns the temp directory — files are cleaned up when the context
manager exits.
Usage::
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
for chunk in store.stream():
...
"""
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
store=store,
)
yield result, store
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -637,6 +762,29 @@ def add_contextual_summaries(
return chunks
def _verify_indexing_completeness(
insertion_records: list[DocumentInsertionRecord],
write_failures: list[ConnectorFailure],
embedding_failed_doc_ids: set[str],
updatable_ids: list[str],
document_index_name: str,
) -> None:
"""Verify that every updatable document was either indexed or reported as failed."""
all_returned_doc_ids = (
{r.document_id for r in insertion_records}
| {f.failed_document.document_id for f in write_failures if f.failed_document}
| embedding_failed_doc_ids
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
f"This should never happen. "
f"This occured for document index {document_index_name}"
)
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -672,12 +820,7 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
return IndexingPipelineResult.empty(len(filtered_documents))
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -716,117 +859,99 @@ def index_doc_batch(
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
if chunks
else ([], [])
)
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
embedding_result,
chunk_store,
):
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk_id,
document_id=document_id,
boost_score=1.0,
)
for chunk_id, document_id in embedding_result.successful_chunk_ids
]
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with adapter.lock_context(context.updatable_docs):
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
context=context,
embedding_failed_doc_ids = _get_failed_doc_ids(
embedding_result.connector_failures
)
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# Filter to only successfully embedded chunks so
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
embedded_chunks = [
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
]
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
# Acquires a lock on the documents so that no other process can modify
# them. Not needed until here, since this is when the actual race
# condition with vector db can occur.
with adapter.lock_context(context.updatable_docs):
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=embedded_chunks,
)
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
None
)
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
None
)
for document_index in document_indices:
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
for chunk in chunk_store.stream():
yield enricher.enrich_chunk(chunk, 1.0)
insertion_records, write_failures = (
write_chunks_to_vector_db_with_backoff(
document_index=document_index,
make_chunks=_enriched_stream,
index_batch_params=index_batch_params,
)
)
_verify_indexing_completeness(
insertion_records=insertion_records,
write_failures=write_failures,
embedding_failed_doc_ids=embedding_failed_doc_ids,
updatable_ids=updatable_ids,
document_index_name=document_index.__class__.__name__,
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
new_docs=sum(
1 for r in primary_doc_idx_insertion_records if not r.already_existed
),
total_docs=len(filtered_documents),
total_chunks=len(chunks_with_embeddings),
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
total_chunks=len(embedding_result.successful_chunk_ids),
failures=primary_doc_idx_vector_db_write_failures
+ embedding_result.connector_failures,
)

View File

@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
class IndexingBatchAdapter(Protocol):
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None: ...

View File

@@ -1,6 +1,9 @@
import time
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present.
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous batches
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
chunks=make_chunks(),
index_batch_params=index_batch_params,
)
),
@@ -60,14 +63,23 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
seen_doc_ids: set[str] = set()
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
if doc_id in seen_doc_ids:
raise RuntimeError(
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
)
seen_doc_ids.add(doc_id)
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
try:
insertion_records.extend(
document_index.index(
@@ -87,9 +99,7 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
document_link=first_chunk.get_link(),
),
failure_message=str(e),
exception=e,

View File

@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
reasoning_effort
)
if budget_tokens is not None:
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024

View File

@@ -439,6 +439,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -175,6 +175,32 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000
def count_tokens(
text: str,
tokenizer: BaseTokenizer,
token_limit: int | None = None,
) -> int:
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
If token_limit is provided and the text is large enough to require
multiple chunks (> 500k chars), stops early once the count exceeds it.
When early-exiting, the returned value exceeds token_limit but may be
less than the true full token count.
"""
if len(text) <= _ENCODE_CHUNK_SIZE:
return len(tokenizer.encode(text))
total = 0
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
if token_limit is not None and total > token_limit:
return total # Already over — skip remaining chunks
return total
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -3844,9 +3844,9 @@
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"version": "5.0.5",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
@@ -4224,9 +4224,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -5007,9 +5007,9 @@
}
},
"node_modules/brace-expansion": {
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -123,9 +123,8 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)

View File

@@ -9,20 +9,15 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
)
)
continue
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
)
continue
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
)
continue
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:

View File

@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
"""Reads Celery queue lengths from the broker Redis on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
self._celery_app: Any | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
if self._celery_app is None:
return []
redis_client = self._get_redis()
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
depth = GaugeMetricFamily(
"onyx_queue_depth",
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
self._celery_app: Any | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
if self._celery_app is None:
return []
redis_client = self._get_redis()
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",

View File

@@ -3,12 +3,8 @@
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# state) until configure() / set_celery_app() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a fresh
celery_app: The Celery application instance. Used to obtain a
broker Redis client on each scrape for queue depth metrics.
"""
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.

View File

@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -17,10 +19,16 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -83,6 +100,16 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,12 +2,19 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -78,7 +85,12 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = None
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
# Connector settings
show_extra_connectors: bool | None = True
@@ -108,3 +120,14 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,13 +1,19 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.5
cryptography==46.0.6
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.11
langchain-core==1.2.22
# via onyx
langdetect==1.0.9
# via unstructured
@@ -735,7 +735,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
# via onyx
pygments==2.19.2
pygments==2.20.0
# via rich
pyjwt==2.12.0
# via

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt
@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.1
onyx-devtools==0.7.2
# via onyx
openai==2.14.0
# via
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.19.2
pygments==2.20.0
# via
# ipython
# ipython-pygments-lexers

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,4 +1,6 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=start_ts,
end=time.time(),
)
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=None,
end=time.time(),
)

View File

@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
Also patches ``celery_get_broker_client`` so the mock app doesn't need
a real broker URL.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
class TestAdapterWritesBothMetadataFields:
"""build_metadata_aware_chunks must populate user_project AND personas."""
"""prepare_enrichment must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -0,0 +1,109 @@
"""Tests for TTL management task resilience."""
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
_TASK_MODULE = "ee.onyx.background.celery.tasks.ttl_management.tasks"
def _setup_db_session_mock(mock_get_db_session: MagicMock) -> None:
mock_db_session = MagicMock()
mock_get_db_session.return_value.__enter__ = MagicMock(
return_value=mock_db_session
)
mock_get_db_session.return_value.__exit__ = MagicMock(return_value=False)
@patch(f"{_TASK_MODULE}.get_session_with_current_tenant")
@patch(f"{_TASK_MODULE}.delete_chat_session")
@patch(f"{_TASK_MODULE}.get_chat_sessions_older_than")
@patch(f"{_TASK_MODULE}.mark_task_as_finished_with_id")
@patch(f"{_TASK_MODULE}.register_task")
def test_ttl_task_continues_after_session_delete_failure(
_mock_register: Any,
mock_mark_finished: MagicMock,
mock_get_old_sessions: MagicMock,
mock_delete_session: MagicMock,
mock_get_db_session: MagicMock,
) -> None:
"""One failing session should not prevent cleanup of remaining sessions."""
from ee.onyx.background.celery.tasks.ttl_management.tasks import (
perform_ttl_management_task,
)
user1, session1 = uuid4(), uuid4()
user2, session2 = uuid4(), uuid4()
user3, session3 = uuid4(), uuid4()
mock_get_old_sessions.return_value = [
(user1, session1),
(user2, session2),
(user3, session3),
]
# Second session fails
mock_delete_session.side_effect = [
None,
RuntimeError("File does not exist"),
None,
]
_setup_db_session_mock(mock_get_db_session)
mock_task = MagicMock()
mock_task.request.id = "test-task-id"
# Call the underlying function directly, bypassing Celery decorator
perform_ttl_management_task.__wrapped__(
mock_task, retention_limit_days=30, tenant_id="test"
)
# All three sessions should have been attempted
assert mock_delete_session.call_count == 3
# Task marked as finished with success=False (due to the one failure)
mock_mark_finished.assert_called()
finish_call_kwargs = mock_mark_finished.call_args[1]
assert finish_call_kwargs["success"] is False
@patch(f"{_TASK_MODULE}.get_session_with_current_tenant")
@patch(f"{_TASK_MODULE}.delete_chat_session")
@patch(f"{_TASK_MODULE}.get_chat_sessions_older_than")
@patch(f"{_TASK_MODULE}.mark_task_as_finished_with_id")
@patch(f"{_TASK_MODULE}.register_task")
def test_ttl_task_reports_success_when_all_deletions_pass(
_mock_register: Any,
mock_mark_finished: MagicMock,
mock_get_old_sessions: MagicMock,
mock_delete_session: MagicMock,
mock_get_db_session: MagicMock,
) -> None:
"""Task should report success when all sessions are deleted."""
from ee.onyx.background.celery.tasks.ttl_management.tasks import (
perform_ttl_management_task,
)
mock_get_old_sessions.return_value = [
(uuid4(), uuid4()),
(uuid4(), uuid4()),
]
mock_delete_session.side_effect = None
_setup_db_session_mock(mock_get_db_session)
mock_task = MagicMock()
mock_task.request.id = "test-task-id"
perform_ttl_management_task.__wrapped__(
mock_task, retention_limit_days=30, tenant_id="test"
)
assert mock_delete_session.call_count == 2
mock_mark_finished.assert_called()
finish_call_kwargs = mock_mark_finished.call_args[1]
assert finish_call_kwargs["success"] is True

View File

@@ -0,0 +1,87 @@
"""Tests for celery_get_broker_client singleton."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.background.celery import celery_redis
@pytest.fixture(autouse=True)
def reset_singleton() -> Iterator[None]:
"""Reset the module-level singleton between tests."""
celery_redis._broker_client = None
celery_redis._broker_url = None
yield
celery_redis._broker_client = None
celery_redis._broker_url = None
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
app = MagicMock()
app.conf.broker_url = broker_url
return app
class TestCeleryGetBrokerClient:
@patch("onyx.background.celery.celery_redis.Redis")
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
result = celery_redis.celery_get_broker_client(app)
assert result is mock_client
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://localhost:6379/15"
assert call_args[1]["decode_responses"] is False
assert call_args[1]["socket_keepalive"] is True
assert call_args[1]["retry_on_timeout"] is True
@patch("onyx.background.celery.celery_redis.Redis")
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
client1 = celery_redis.celery_get_broker_client(app)
client2 = celery_redis.celery_get_broker_client(app)
assert client1 is client2
# from_url called only once
assert mock_redis_cls.from_url.call_count == 1
@patch("onyx.background.celery.celery_redis.Redis")
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
stale_client = MagicMock()
stale_client.ping.side_effect = ConnectionError("disconnected")
fresh_client = MagicMock()
fresh_client.ping.return_value = True
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
app = _make_mock_app()
# First call creates stale_client
client1 = celery_redis.celery_get_broker_client(app)
assert client1 is stale_client
# Second call: ping fails, creates fresh_client
client2 = celery_redis.celery_get_broker_client(app)
assert client2 is fresh_client
assert mock_redis_cls.from_url.call_count == 2
@patch("onyx.background.celery.celery_redis.Redis")
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
mock_redis_cls.from_url.return_value = MagicMock()
app = _make_mock_app("redis://custom-host:6380/3")
celery_redis.celery_get_broker_client(app)
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://custom-host:6380/3"

View File

@@ -1,15 +1,23 @@
"""Tests for Canvas connector — client (PR1)."""
"""Tests for Canvas connector — client, credentials, conversion."""
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.canvas.connector import CanvasConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -18,6 +26,77 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
FAKE_TOKEN = "fake-canvas-token"
def _mock_course(
course_id: int = 1,
name: str = "Intro to CS",
course_code: str = "CS101",
) -> dict[str, Any]:
return {
"id": course_id,
"name": name,
"course_code": course_code,
"created_at": "2025-01-01T00:00:00Z",
"workflow_state": "available",
}
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
"""Build a connector with mocked credential validation."""
with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=base_url)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
return connector
def _mock_page(
page_id: int = 10,
title: str = "Syllabus",
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"page_id": page_id,
"url": "syllabus",
"title": title,
"body": "<p>Welcome to the course</p>",
"created_at": "2025-01-15T00:00:00Z",
"updated_at": updated_at,
}
def _mock_assignment(
assignment_id: int = 20,
name: str = "Homework 1",
course_id: int = 1,
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": assignment_id,
"name": name,
"description": "<p>Solve these problems</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
"course_id": course_id,
"created_at": "2025-01-20T00:00:00Z",
"updated_at": updated_at,
"due_at": "2025-02-01T23:59:00Z",
}
def _mock_announcement(
announcement_id: int = 30,
title: str = "Class Cancelled",
course_id: int = 1,
posted_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": announcement_id,
"title": title,
"message": "<p>No class today</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
"posted_at": posted_at,
}
def _mock_response(
status_code: int = 200,
json_data: Any = None,
@@ -325,6 +404,57 @@ class TestGet:
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.paginate tests
# ---------------------------------------------------------------------------
class TestPaginate:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[{"id": 1}, {"id": 2}]
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 1
assert pages[0] == [{"id": 1}, {"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_two_pages(self, mock_requests: MagicMock) -> None:
next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
page2 = _mock_response(json_data=[{"id": 2}])
mock_requests.get.side_effect = [page1, page2]
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 2
assert pages[0] == [{"id": 1}]
assert pages[1] == [{"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert pages == []
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
# ---------------------------------------------------------------------------
@@ -379,3 +509,368 @@ class TestParseNextLink:
with pytest.raises(OnyxError, match="must use https"):
self.client._parse_next_link(header)
# ---------------------------------------------------------------------------
# CanvasConnector — credential loading
# ---------------------------------------------------------------------------
class TestLoadCredentials:
def _assert_load_credentials_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert load_credentials raises expected_error for a given status."""
mock_requests.get.return_value = _mock_response(status_code, {})
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(expected_error):
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
assert result is None
assert connector._canvas_client is not None
def test_canvas_client_raises_without_credentials(self) -> None:
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(ConnectorMissingCredentialError):
_ = connector.canvas_client
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_insufficient_permissions(
self, mock_requests: MagicMock
) -> None:
self._assert_load_credentials_raises(
403, InsufficientPermissionsError, mock_requests
)
# ---------------------------------------------------------------------------
# CanvasConnector — URL normalization
# ---------------------------------------------------------------------------
class TestConnectorUrlNormalization:
def test_strips_api_v1_suffix(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_strips_trailing_slash(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_no_change_for_clean_url(self) -> None:
connector = _build_connector(base_url=FAKE_BASE_URL)
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
# ---------------------------------------------------------------------------
# CanvasConnector — document conversion
# ---------------------------------------------------------------------------
class TestDocumentConversion:
def setup_method(self) -> None:
self.connector = _build_connector()
def test_convert_page_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=10,
url="syllabus",
title="Syllabus",
body="<p>Welcome</p>",
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
expected_id = "canvas-page-1-10"
expected_metadata = {"course_id": "1", "type": "page"}
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Syllabus"
assert doc.metadata == expected_metadata
assert doc.sections[0].link is not None
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
assert doc.doc_updated_at == expected_updated_at
def test_convert_page_without_body(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=11,
url="empty-page",
title="Empty Page",
body=None,
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
section_text = doc.sections[0].text
assert section_text is not None
assert "Empty Page" in section_text
assert "<p>" not in section_text
def test_convert_assignment_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=20,
name="Homework 1",
description="<p>Solve these</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at="2025-02-01T23:59:00Z",
)
doc = self.connector._convert_assignment_to_document(assignment)
expected_id = "canvas-assignment-1-20"
expected_due_text = "Due: February 01, 2025 23:59 UTC"
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Homework 1"
assert doc.sections[0].text is not None
assert expected_due_text in doc.sections[0].text
def test_convert_assignment_without_description(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=21,
name="Quiz 1",
description=None,
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at=None,
)
doc = self.connector._convert_assignment_to_document(assignment)
section_text = doc.sections[0].text
assert section_text is not None
assert "Quiz 1" in section_text
assert "Due:" not in section_text
def test_convert_announcement_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=30,
title="Class Cancelled",
message="<p>No class today</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
posted_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
expected_id = "canvas-announcement-1-30"
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Class Cancelled"
assert doc.doc_updated_at == expected_updated_at
def test_convert_announcement_without_posted_at(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=31,
title="TBD Announcement",
message=None,
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
posted_at=None,
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
assert doc.doc_updated_at is None
# ---------------------------------------------------------------------------
# CanvasConnector — validate_connector_settings
# ---------------------------------------------------------------------------
class TestValidateConnectorSettings:
def _assert_validate_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert validate_connector_settings raises expected_error."""
success_resp = _mock_response(json_data=[_mock_course()])
fail_resp = _mock_response(status_code, {})
mock_requests.get.side_effect = [success_resp, fail_resp]
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
with pytest.raises(expected_error):
connector.validate_connector_settings()
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = _build_connector()
connector.validate_connector_settings() # should not raise
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
# ---------------------------------------------------------------------------
# _list_* pagination tests
# ---------------------------------------------------------------------------
class TestListCourses:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
)
connector = _build_connector()
result = connector._list_courses()
assert len(result) == 2
assert result[0].id == 1
assert result[1].id == 2
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_courses()
assert result == []
class TestListPages:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_page(10), _mock_page(11, "Notes")]
)
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert len(result) == 2
assert result[0].page_id == 10
assert result[1].page_id == 11
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert result == []
class TestListAssignments:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
)
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert len(result) == 2
assert result[0].id == 20
assert result[1].id == 21
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert result == []
class TestListAnnouncements:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
)
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert len(result) == 2
assert result[0].id == 30
assert result[1].id == 31
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert result == []

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that generic_doc_sync derives indexing_start from cc_pair
and forwards it to retrieve_all_slim_docs_perm_sync."""
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that indexing_start is None when the connector has no indexing_start set."""
mock_jira_cc_pair.connector.indexing_start = None
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] is None

View File

@@ -0,0 +1,58 @@
"""Tests for chat session and message deletion resilience."""
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.db.chat import delete_messages_and_files_from_chat_session
@patch("onyx.db.chat.delete_orphaned_search_docs")
@patch("onyx.db.chat.get_default_file_store")
def test_delete_messages_skips_missing_files(
mock_get_file_store: MagicMock,
_mock_delete_orphaned: Any,
) -> None:
"""Deletion should continue when a referenced file record no longer exists."""
session_id = uuid4()
file_store = MagicMock()
file_store.delete_file.side_effect = [
None, # first file deletes fine
RuntimeError("File by id abc does not exist or was deleted"),
None, # third file deletes fine
]
mock_get_file_store.return_value = file_store
mock_db_session = MagicMock()
mock_db_session.execute.return_value.fetchall.return_value = [
(1, [{"id": "file-ok-1"}, {"id": "file-missing"}, {"id": "file-ok-2"}]),
]
delete_messages_and_files_from_chat_session(session_id, mock_db_session)
assert file_store.delete_file.call_count == 3
mock_db_session.execute.assert_called()
mock_db_session.commit.assert_called()
@patch("onyx.db.chat.delete_orphaned_search_docs")
@patch("onyx.db.chat.get_default_file_store")
def test_delete_messages_succeeds_with_no_files(
mock_get_file_store: MagicMock,
_mock_delete_orphaned: Any,
) -> None:
"""Deletion works when messages have no attached files."""
session_id = uuid4()
mock_db_session = MagicMock()
mock_db_session.execute.return_value.fetchall.return_value = [
(1, None),
(2, []),
]
delete_messages_and_files_from_chat_session(session_id, mock_db_session)
mock_get_file_store.return_value.delete_file.assert_not_called()
mock_db_session.commit.assert_called()

View File

@@ -0,0 +1,223 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int,
) -> DocMetadataAwareIndexChunk:
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
doc = Document(
id=doc_id,
sections=[TextSection(text="test", link="http://test.com")],
source=DocumentSource.FILE,
semantic_identifier="test_doc",
metadata={},
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links={0: "http://test.com"},
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(full_embedding=[0.1] * 10, mini_chunk_embeddings=[]),
title_embedding=[0.1] * 10,
tenant_id="test_tenant",
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
ancestor_hierarchy_node_ids=[],
)
def _make_index() -> tuple[OpenSearchDocumentIndex, MagicMock]:
"""Creates an OpenSearchDocumentIndex with a mocked client.
Returns the index and the mock for bulk_index_documents."""
mock_client = MagicMock()
mock_bulk = MagicMock()
mock_client.bulk_index_documents = mock_bulk
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
index._index_name = "test_index"
index._client = mock_client
index._tenant_state = tenant_state
return index, mock_bulk
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=0,
new_chunk_cnt=chunk_count,
),
},
)
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_under_batch_limit_flushes_once() -> None:
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 50
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 1
batch_arg = mock_bulk.call_args_list[0]
assert len(batch_arg.kwargs["documents"]) == num_chunks
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
assert mock_bulk.call_count == 3
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 100, 50]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_exactly_at_batch_limit() -> None:
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
(the flush happens on the next chunk, not at the boundary)."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 100
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
# so final flush handles all 100
# Actually: the elif fires when len(current_chunks) >= 100, which happens
# when current_chunks has 100 items and the 101st chunk arrives.
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
# No 101st chunk arrives, so the final flush handles all 100.
assert mock_bulk.call_count == 1
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_one_over_batch_limit() -> None:
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
the 101st is flushed at the end."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 101
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 2
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 1]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
"""Multiple documents each under the batch limit should flush once per document."""
index, mock_bulk = _make_index()
chunks = []
for doc_idx in range(3):
doc_id = f"doc_{doc_idx}"
for chunk_idx in range(50):
chunks.append(_make_chunk(doc_id, chunk_idx))
metadata = IndexingMetadata(
doc_id_to_chunk_cnt_diff={
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
for i in range(3)
},
)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 3 documents = 3 flushes (one per doc boundary + final)
assert mock_bulk.call_count == 3
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_delete_called_once_per_document() -> None:
"""Even with multiple flushes for a single document, delete should only be
called once per document."""
index, _mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0) as mock_delete:
index.index(chunks, metadata)
mock_delete.assert_called_once_with(doc_id, None)

View File

@@ -0,0 +1,152 @@
"""Unit tests for VespaDocumentIndex.index().
These tests mock all external I/O (HTTP calls, thread pools) and verify
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
construction.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
) -> DocMetadataAwareIndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test_doc",
sections=[TextSection(text=content, link=None)],
source=DocumentSource.NOT_APPLICABLE,
metadata={},
)
index_chunk = IndexChunk(
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=index_chunk,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
tenant_id="test_tenant",
)
def _make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
def _stub_enrich(
doc_id: str,
old_chunk_cnt: int,
) -> EnrichedDocumentIndexingInfo:
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
return EnrichedDocumentIndexingInfo(
doc_id=doc_id,
chunk_start_index=0,
old_version=False,
chunk_end_index=old_chunk_cnt,
)
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
@patch(
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
3,
)
def test_index_respects_batch_size(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
multiple times with correctly sized batches."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
results = index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
assert mock_batch_index.call_count == 3
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
assert batch_sizes == [3, 3, 1]
# Verify all chunks are accounted for and in order
all_indexed = [
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
]
assert len(all_indexed) == 7
assert [c.chunk_id for c in all_indexed] == list(range(7))

View File

@@ -0,0 +1,391 @@
"""Unit tests for _embed_chunks_to_store.
Tests cover:
- Single batch, no failures
- Multiple batches, no failures
- Failure in a single batch
- Cross-batch document failure scrubbing
- Later batches skip already-failed docs
- Empty input
- All chunks fail
"""
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
def _make_doc(doc_id: str) -> Document:
return Document(
id=doc_id,
semantic_identifier="test",
source=DocumentSource.FILE,
sections=[TextSection(text="test", link=None)],
metadata={},
)
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
return DocAwareChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
)
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
"""Create an IndexChunk (a DocAwareChunk with embeddings)."""
return IndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
def _make_failure(doc_id: str) -> ConnectorFailure:
return ConnectorFailure(
failed_document=DocumentFailure(document_id=doc_id, document_link=None),
failure_message="embedding failed",
exception=RuntimeError("embedding failed"),
)
def _mock_embed_success(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Simulate successful embedding of all chunks."""
return (
[_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
[],
)
def _mock_embed_fail_doc(
fail_doc_id: str,
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
"""Return an embed mock that fails all chunks for a specific doc."""
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
successes = [
_make_index_chunk(c.source_document.id, c.chunk_id)
for c in chunks
if c.source_document.id != fail_doc_id
]
failures = (
[_make_failure(fail_doc_id)]
if any(c.source_document.id == fail_doc_id for c in chunks)
else []
)
return successes, failures
return _embed
class TestEmbedChunksInBatches:
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
"""All chunks fit in one batch and embed successfully."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(3)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 3
assert len(result.connector_failures) == 0
# Verify stored contents
assert len(store._batch_files()) == 1
stored = list(store.stream())
assert len(stored) == 3
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
"""Chunks are split across multiple batches, all succeed."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(7)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 7
assert len(result.connector_failures) == 0
assert len(store._batch_files()) == 3 # 3 + 3 + 1
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
"""One doc fails embedding, its chunks are excluded from results."""
mock_embed.side_effect = _mock_embed_fail_doc("doc2")
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc2", 1),
_make_chunk("doc1", 2),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.connector_failures) == 1
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc2" not in successful_doc_ids
assert "doc1" in successful_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_cross_batch_failure_scrubs_earlier_batch(
self, mock_embed: MagicMock
) -> None:
"""Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
batch 1. Its chunks should be scrubbed from batch 0's batch file."""
call_count = 0
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
nonlocal call_count
call_count += 1
if call_count == 1:
return _mock_embed_success(chunks)
else:
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should be fully excluded from results
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "docA" not in successful_doc_ids
assert "docB" in successful_doc_ids
assert len(result.connector_failures) == 1
# Verify batch 0 was scrubbed of docA chunks
all_stored = list(store.stream())
stored_doc_ids = {c.source_document.id for c in all_stored}
assert "docA" not in stored_doc_ids
assert "docB" in stored_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
"""If docA fails in batch 0, its chunks in batch 1 are skipped
entirely (never sent to the embedder)."""
embedded_doc_ids: list[str] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
for c in chunks:
embedded_doc_ids.append(c.source_document.id)
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
_embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should only appear in batch 0, not batch 1
batch_1_doc_ids = embedded_doc_ids[3:]
assert "docA" not in batch_1_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
self, mock_embed: MagicMock
) -> None:
"""doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
should still be embedded successfully."""
embedded_chunks: list[list[str]] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
embedded_chunks.append([c.source_document.id for c in chunks])
return _mock_embed_fail_doc("doc1")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc1", 1),
_make_chunk("doc1", 2),
_make_chunk("doc1", 3),
_make_chunk("doc2", 0),
_make_chunk("doc2", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# doc1 should be fully excluded, doc2 fully included
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc1" not in successful_doc_ids
assert "doc2" in successful_doc_ids
assert len(result.successful_chunk_ids) == 2 # doc2's 2 chunks
# Batch 1 should only contain doc2 (doc1 was filtered before embedding)
assert len(embedded_chunks) == 2
assert "doc1" not in embedded_chunks[1]
assert embedded_chunks[1] == ["doc2", "doc2"]
# Verify on-disk state has no doc1 chunks
all_stored = list(store.stream())
assert all(c.source_document.id == "doc2" for c in all_stored)
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
def test_empty_input(self, mock_embed: MagicMock) -> None:
"""Empty chunk list produces empty results."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=[],
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 0
mock_embed.assert_not_called()
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
"""When all documents fail, results have no successful chunks."""
def _fail_all(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
doc_ids = {c.source_document.id for c in chunks}
return [], [_make_failure(doc_id) for doc_id in doc_ids]
mock_embed.side_effect = _fail_all
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 2

View File

@@ -116,7 +116,7 @@ def _run_adapter_build(
project_ids_map: dict[str, list[int]],
persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
with all external dependencies mocked."""
from onyx.indexing.adapters.user_file_indexing_adapter import (
UserFileIndexingAdapter,
@@ -155,18 +155,16 @@ def _run_adapter_build(
side_effect=Exception("no LLM in tests"),
),
):
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id="test_tenant",
enricher = adapter.prepare_enrichment(
context=context,
tenant_id="test_tenant",
chunks=[chunk],
)
return result.chunks
return [enricher.enrich_chunk(chunk, 1.0)]
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
def test_prepare_enrichment_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
fetched from the DB into each chunk's metadata."""
file_id = str(uuid4())
persona_ids = [5, 12]
@@ -183,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
assert chunks[0].user_project == project_ids
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
"""When a file has no persona or project associations in the DB, the
adapter should default to empty lists (not KeyError or None)."""
file_id = str(uuid4())

View File

@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction
import onyx.llm.models
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLMUserIdentity
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
def test_prompt_contains_tool_call_history_true() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(name="get_weather", arguments="{}"),
)
],
),
]
assert _prompt_contains_tool_call_history(messages) is True
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="Hello"),
AssistantMessage(content="Hi there!"),
]
assert _prompt_contains_tool_call_history(messages) is False
def test_prompt_contains_tool_call_history_false_user_only() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [UserMessage(content="Hello")]
assert _prompt_contains_tool_call_history(messages) is False
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
"""When thinking is enabled but assistant messages with tool_calls lack
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
BadRequestError about missing thinking blocks."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
onyx.llm.models.ToolMessage(
content="22°C sunny",
tool_call_id="tc_1",
),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" not in kwargs, (
"thinking param should be dropped when thinking_blocks are missing "
"from assistant messages with tool_calls"
)
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
"""When thinking is enabled and there are no historical assistant messages
with tool_calls, the thinking param should be preserved."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" in kwargs, (
"thinking param should be preserved when no assistant messages "
"with tool_calls exist in history"
)
assert kwargs["thinking"]["type"] == "enabled"
def test_bifrost_claude_includes_allowed_openai_params() -> None:
llm = LitellmLLM(
api_key="test_key",

View File

@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.natural_language_processing import utils as nlp_utils
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.server.features.projects import projects_file_utils as utils
from onyx.server.settings.models import Settings
class _Tokenizer:
class _Tokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
return [1] * len(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
return Settings(
user_file_max_upload_size_mb=upload_size_mb,
file_token_count_threshold_k=token_threshold_k,
)
def _patch_common_dependencies(
monkeypatch: pytest.MonkeyPatch,
upload_size_mb: int = 1,
token_threshold_k: int = 100,
) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
monkeypatch.setattr(
utils,
"load_settings",
lambda: _make_settings(upload_size_mb, token_threshold_k),
)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("edge.png", size=100)
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
upload = _make_upload("edge.png", size=1048576)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=101)
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=101)
large = _make_upload("large.png", size=1048577)
result = utils.categorize_uploaded_files([small, large], MagicMock())
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
def test_categorize_uploaded_files_enforces_size_limit_always(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
upload = _make_upload("oversized.pdf", size=101)
upload = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=101)
oversized_doc = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_accepts_python_file(
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
"""A positive upload_size_mb is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
py_source = b'def hello():\n print("world")\n'
monkeypatch.setattr(
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
)
upload = _make_upload("script.py", size=len(py_source), content=py_source)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "script.py"
assert len(result.rejected) == 0
def test_categorize_uploaded_files_rejects_binary_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
binary_content = bytes(range(256)) * 4
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
upload = _make_upload("huge.png", size=1048577, content=b"x")
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].filename == "data.bin"
assert "Unsupported file type" in result.rejected[0].reason
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive token_threshold_k is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("big_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_no_token_limit_when_threshold_k_is_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
monkeypatch.setattr(
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
)
upload = _make_upload("huge_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 0
assert len(result.acceptable) == 1
def test_categorize_both_limits_enforced(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Both positive limits are enforced; file exceeding token limit is rejected."""
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("over_tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 5K token limit"
def test_categorize_rejection_reason_contains_dynamic_values(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
# File within size limit but over token limit
upload = _make_upload("tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert result.rejected[0].reason == "Exceeds 7K token limit"
# File over size limit
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
# --- count_tokens tests ---
def test_count_tokens_small_text() -> None:
"""Small text should be encoded in a single call and return correct count."""
tokenizer = _Tokenizer()
text = "hello world"
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_chunked_matches_single_call() -> None:
"""Chunked encoding should produce the same result as single-call for small text."""
tokenizer = _Tokenizer()
text = "a" * 1000
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
# _Tokenizer returns 1 token per char, so total should be 250
assert count_tokens(text, tokenizer) == 250
def test_count_tokens_with_token_limit_exits_early(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set and exceeded, count_tokens should stop early."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
tokenizer = _CountingTokenizer()
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
text = "a" * 500
result = count_tokens(text, tokenizer, token_limit=150)
assert result == 200 # 2 chunks × 100 tokens each
assert encode_call_count == 2, "Should have stopped after 2 chunks"
def test_count_tokens_with_token_limit_not_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set but not exceeded, all chunks are encoded."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
result = count_tokens(text, tokenizer, token_limit=1000)
assert result == 250
def test_count_tokens_no_limit_encodes_all_chunks(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Without token_limit, all chunks are encoded regardless of count."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 500
result = count_tokens(text, tokenizer)
assert result == 500
# --- early exit via token_limit in categorize tests ---
def test_categorize_early_exits_tokenization_for_large_text(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Large text files should be rejected via early-exit tokenization
without encoding all chunks."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
# Let's use a bigger text
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
large_text = "x" * 5000 # 5000 tokens, threshold 1000
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 1
assert "token limit" in result.rejected[0].reason
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
assert (
encode_call_count < 50
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
def test_categorize_text_under_token_limit_accepted(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Text files under the token threshold should be accepted with exact count."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
small_text = "x" * 500 # 500 tokens < 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable_file_to_token_count["ok.txt"] == 500

View File

@@ -1,12 +1,23 @@
import pytest
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
class _FakeKvStore:
def __init__(self, data: dict | None = None) -> None:
self._data = data
def load(self, _key: str) -> dict:
raise KvKeyNotFoundError()
if self._data is None:
raise KvKeyNotFoundError()
return self._data
class _FakeCache:
@@ -20,13 +31,140 @@ class _FakeCache:
self._vals[key] = value.encode("utf-8")
def test_load_settings_includes_user_file_max_upload_size_mb(
def test_load_settings_uses_model_defaults_when_no_stored_value(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When no settings are stored (vector DB enabled), load_settings() should
resolve the default token threshold to 200."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 77
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled and no settings are stored, the token
threshold should default to 10000 (10M tokens)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
)
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled but admin explicitly set a token threshold,
that value should be preserved (not overridden by the 10000 default)."""
stored = Settings(file_token_count_threshold_k=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 500
def test_load_settings_preserves_zero_token_threshold(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 means 'no limit' and should be preserved."""
stored = Settings(file_token_count_threshold_k=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 0
def test_load_settings_resolves_zero_upload_size_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
def test_load_settings_clamps_upload_size_to_env_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be clamped to the env-configured maximum."""
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 250
def test_load_settings_preserves_upload_size_within_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be preserved unchanged."""
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 150
def test_load_settings_zero_upload_size_resolves_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default,
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 100
def test_load_settings_default_clamped_to_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
the effective default should be min(DEFAULT, MAX)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 50

View File

@@ -1,5 +1,6 @@
"""Tests for indexing pipeline Prometheus collectors."""
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
@pytest.fixture(autouse=True)
def _mock_broker_client() -> Iterator[None]:
"""Patch celery_get_broker_client for all collector tests."""
with patch(
"onyx.background.celery.celery_redis.celery_get_broker_client",
return_value=MagicMock(),
):
yield
class TestQueueDepthCollector:
def test_returns_empty_when_factory_not_set(self) -> None:
collector = QueueDepthCollector()
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
def test_collects_queue_depths(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
collector.set_celery_app(MagicMock())
with (
patch(
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
def test_handles_redis_error_gracefully(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
with patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
def test_caching_returns_stale_within_ttl(self) -> None:
collector = QueueDepthCollector(cache_ttl=60)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
with (
patch(
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
assert first is second # Same object, from cache
def test_factory_called_each_scrape(self) -> None:
"""Verify the Redis factory is called on each fresh collect, not cached."""
collector = QueueDepthCollector(cache_ttl=0)
factory = MagicMock(return_value=MagicMock())
collector.set_redis_factory(factory)
with (
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
return_value=0,
),
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
return_value=set(),
),
):
collector.collect()
collector.collect()
assert factory.call_count == 2
def test_error_returns_stale_cache(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
MagicMock()
collector.set_celery_app(MagicMock())
# First call succeeds
with (

View File

@@ -1,96 +1,22 @@
"""Tests for indexing pipeline setup (Redis factory caching)."""
"""Tests for indexing pipeline setup."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
class TestCollectorCeleryAppSetup:
def test_queue_depth_collector_uses_celery_app(self) -> None:
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
collector = QueueDepthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
def test_caches_redis_client_across_calls(self) -> None:
"""Factory should reuse the same client on subsequent calls."""
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client1 = factory()
client2 = factory()
assert client1 is client2
# broker_connection should only be called once
assert mock_app.broker_connection.call_count == 1
def test_reconnects_when_ping_fails(self) -> None:
"""Factory should create a new client if ping fails (stale connection)."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
client1 = factory()
assert client1 is mock_client_stale
assert mock_app.broker_connection.call_count == 1
# Switch to fresh client for next connection
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails on stale, reconnects
client2 = factory()
assert client2 is mock_client_fresh
assert mock_app.broker_connection.call_count == 2
def test_reconnect_closes_stale_client(self) -> None:
"""When ping fails, the old client should be closed before reconnecting."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
factory()
# Switch to fresh client
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails, should close stale client
factory()
mock_client_stale.close.assert_called_once()
def test_first_call_creates_connection(self) -> None:
"""First call should always create a new connection."""
mock_client = MagicMock()
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client = factory()
assert client is mock_client
mock_app.broker_connection.assert_called_once()
def test_redis_health_collector_uses_celery_app(self) -> None:
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
collector = RedisHealthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app

1
cli/.gitignore vendored
View File

@@ -1,3 +1,4 @@
onyx-cli
cli
onyx.cli
__pycache__

View File

@@ -63,6 +63,31 @@ onyx-cli agents
onyx-cli agents --json
```
### Serve over SSH
```shell
# Start a public SSH endpoint for the CLI TUI
onyx-cli serve --host 0.0.0.0 --port 2222
# Connect as a client
ssh your-host -p 2222
```
Clients can either:
- paste an API key at the login prompt, or
- skip the prompt by sending `ONYX_API_KEY` over SSH:
```shell
export ONYX_API_KEY=your-key
ssh -o SendEnv=ONYX_API_KEY your-host -p 2222
```
Useful hardening flags:
- `--idle-timeout` (default `15m`)
- `--max-session-timeout` (default `8h`)
- `--rate-limit-per-minute` (default `20`)
- `--rate-limit-burst` (default `40`)
## Commands
| Command | Description |
@@ -70,6 +95,7 @@ onyx-cli agents --json
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `serve` | Serve the interactive chat TUI over SSH |
| `configure` | Configure server URL and API key |
| `validate-config` | Validate configuration and test connection |

View File

@@ -1,7 +1,17 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd
import "github.com/spf13/cobra"
import (
"context"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
// Version and Commit are set via ldflags at build time.
var (
@@ -16,15 +26,69 @@ func fullVersion() string {
return Version
}
func printVersion(cmd *cobra.Command) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Client version: %s\n", fullVersion())
cfg := config.Load()
if !cfg.IsConfigured() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (not configured)\n")
return
}
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
log.Debug("fetching backend version from /api/version")
backendVersion, err := client.GetBackendVersion(ctx)
if err != nil {
log.WithError(err).Debug("could not fetch backend version")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (could not reach server)\n")
return
}
if backendVersion == "" {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (empty response)\n")
return
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: %s\n", backendVersion)
min := version.MinServer()
if sv, ok := version.Parse(backendVersion); ok && sv.LessThan(min) {
log.Warnf("Server version %s is below minimum required %d.%d, please upgrade",
backendVersion, min.Major, min.Minor)
}
}
// Execute creates and runs the root command.
func Execute() error {
opts := struct {
Debug bool
}{}
rootCmd := &cobra.Command{
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
Version: fullVersion(),
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if opts.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
}
log.SetFormatter(&log.TextFormatter{
DisableTimestamp: true,
})
},
}
rootCmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
// Custom --version flag instead of Cobra's built-in (which only shows one version string)
var showVersion bool
rootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "Print client and server version information")
// Register subcommands
chatCmd := newChatCmd()
rootCmd.AddCommand(chatCmd)
@@ -32,9 +96,16 @@ func Execute() error {
rootCmd.AddCommand(newAgentsCmd())
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
rootCmd.AddCommand(newServeCmd())
// Default command is chat
rootCmd.RunE = chatCmd.RunE
// Default command is chat, but intercept --version first
rootCmd.RunE = func(cmd *cobra.Command, args []string) error {
if showVersion {
printVersion(cmd)
return nil
}
return chatCmd.RunE(cmd, args)
}
return rootCmd.Execute()
}

Some files were not shown because too many files have changed in this diff Show More