mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 06:52:42 +00:00
Compare commits
61 Commits
refactor/r
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8e565fa75 | ||
|
|
bab95d8bf0 | ||
|
|
eb7bc74e1b | ||
|
|
29da0aefb5 | ||
|
|
6c86301c51 | ||
|
|
631146f48f | ||
|
|
f327278506 | ||
|
|
c7cc439862 | ||
|
|
3365a369e2 | ||
|
|
470bda3fb5 | ||
|
|
13f511e209 | ||
|
|
c5e8ba1eab | ||
|
|
0b9d154a73 | ||
|
|
6e65e55bf5 | ||
|
|
3f9e208759 | ||
|
|
fb8edda14a | ||
|
|
58decd8a6b | ||
|
|
e97204d9cc | ||
|
|
44ab02c94f | ||
|
|
a98cc30f25 | ||
|
|
a709dcb8fa | ||
|
|
a3dfe6aa1b | ||
|
|
23e4d55fb1 | ||
|
|
470cc85f83 | ||
|
|
64d9be5a41 | ||
|
|
71a5b469b0 | ||
|
|
462eb0697f | ||
|
|
b708dc8796 | ||
|
|
c9e2c32f55 | ||
|
|
d725df62e7 | ||
|
|
d1460972b6 | ||
|
|
706872f0b7 | ||
|
|
ed3856be2b | ||
|
|
6326c7f0b9 | ||
|
|
40420fc4e6 | ||
|
|
1a2b6a66cc | ||
|
|
d1b1529ccf | ||
|
|
fedd9c76e5 | ||
|
|
0b34b40b79 | ||
|
|
fe82ddb1b9 | ||
|
|
32d3d70525 | ||
|
|
40b9e10890 | ||
|
|
e21b204b8a | ||
|
|
2f672b3a4f | ||
|
|
cf19d0df4f | ||
|
|
86a6a4c134 | ||
|
|
146b5449d2 | ||
|
|
b66991b5c5 | ||
|
|
9cb76dc027 | ||
|
|
f66891d19e | ||
|
|
c07c952ad5 | ||
|
|
be7f40a28a | ||
|
|
26f941b5da | ||
|
|
b9e84c42a8 | ||
|
|
0a1df52c2f | ||
|
|
306b0d452f | ||
|
|
5fdb34ba8e | ||
|
|
2d066631e3 | ||
|
|
5c84f6c61b | ||
|
|
899179d4b6 | ||
|
|
80d6bafc74 |
3
.github/workflows/helm-chart-releases.yml
vendored
3
.github/workflows/helm-chart-releases.yml
vendored
@@ -47,7 +47,8 @@ jobs:
|
||||
done
|
||||
|
||||
- name: Publish Helm charts to gh-pages
|
||||
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
|
||||
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
|
||||
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
charts_dir: deployment/helm/charts
|
||||
|
||||
@@ -35,6 +35,7 @@ jobs:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
environment: ci-protected
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
@@ -183,6 +183,7 @@ jobs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
|
||||
runs-on: ubuntu-slim
|
||||
environment: ci-protected
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -232,6 +233,7 @@ jobs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
|
||||
runs-on: ubuntu-slim
|
||||
environment: ci-protected
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- name: Cache Cargo registry and build
|
||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
|
||||
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/bin/
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -284,7 +284,7 @@ jobs:
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
|
||||
@@ -626,7 +626,7 @@ jobs:
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
|
||||
|
||||
2
.github/workflows/pr-python-checks.yml
vendored
2
.github/workflows/pr-python-checks.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
|
||||
- name: Cache mypy cache
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
|
||||
1
.github/workflows/pr-python-model-tests.yml
vendored
1
.github/workflows/pr-python-model-tests.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-model-check"
|
||||
- "extras=ecr-cache"
|
||||
environment: ci-protected
|
||||
timeout-minutes: 45
|
||||
|
||||
env:
|
||||
|
||||
1
.github/workflows/preview.yml
vendored
1
.github/workflows/preview.yml
vendored
@@ -15,6 +15,7 @@ permissions:
|
||||
jobs:
|
||||
Deploy-Preview:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ci-protected
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
|
||||
17
.github/workflows/release-cli.yml
vendored
17
.github/workflows/release-cli.yml
vendored
@@ -13,15 +13,6 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
@@ -31,9 +22,11 @@ jobs:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
for goos in linux windows darwin; do
|
||||
for goarch in amd64 arm64; do
|
||||
GOOS="$goos" GOARCH="$goarch" uv build --wheel
|
||||
done
|
||||
done
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
2
.github/workflows/storybook-deploy.yml
vendored
2
.github/workflows/storybook-deploy.yml
vendored
@@ -25,6 +25,7 @@ permissions:
|
||||
jobs:
|
||||
Deploy-Storybook:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ci-protected
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
@@ -54,6 +55,7 @@ jobs:
|
||||
needs: Deploy-Storybook
|
||||
if: always() && needs.Deploy-Storybook.result == 'failure'
|
||||
runs-on: ubuntu-latest
|
||||
environment: ci-protected
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
|
||||
1
.github/workflows/sync_foss.yml
vendored
1
.github/workflows/sync_foss.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
jobs:
|
||||
sync-foss:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ci-protected
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
1
.github/workflows/tag-nightly.yml
vendored
1
.github/workflows/tag-nightly.yml
vendored
@@ -11,6 +11,7 @@ permissions:
|
||||
jobs:
|
||||
create-and-push-tag:
|
||||
runs-on: ubuntu-slim
|
||||
environment: ci-protected
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
|
||||
@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
|
||||
|
||||
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
|
||||
@@ -122,7 +122,7 @@ repos:
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
language_version: "1.26.0"
|
||||
language_version: "1.26.1"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
|
||||
@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
|
||||
> [!TIP]
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> ```
|
||||
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
|
||||
> curl -fsSL https://onyx.app/install_onyx.sh | bash
|
||||
> ```
|
||||
|
||||
****
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
@@ -473,6 +474,8 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
@@ -105,9 +106,11 @@ def _get_slack_document_access(
|
||||
slack_connector: SlackConnector,
|
||||
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
indexing_start: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
callback=callback,
|
||||
start=indexing_start,
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
@@ -180,9 +183,15 @@ def slack_doc_sync(
|
||||
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.set_credentials_provider(provider)
|
||||
indexing_start_ts: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
slack_connector,
|
||||
slack_connector=slack_connector,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
indexing_start=indexing_start_ts,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -40,10 +41,19 @@ def generic_doc_sync(
|
||||
|
||||
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
|
||||
|
||||
indexing_start: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=indexing_start,
|
||||
callback=callback,
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
|
||||
# Self-hosted customers can raise or lower this via the environment variable.
|
||||
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
|
||||
if _raw_max_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
|
||||
_raw_max_upload_size_mb,
|
||||
)
|
||||
_raw_max_upload_size_mb = 250
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
|
||||
|
||||
# Default fallback for the per-user file upload size limit (in MB) when no
|
||||
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
|
||||
# runtime so this never silently exceeds the hard ceiling.
|
||||
_raw_default_upload_size_mb = int(
|
||||
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
|
||||
)
|
||||
if _raw_default_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
|
||||
_raw_default_upload_size_mb,
|
||||
)
|
||||
_raw_default_upload_size_mb = 100
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -791,6 +805,10 @@ MINI_CHUNK_SIZE = 150
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# The maximum number of chunks that can be held for 1 document processing batch
|
||||
# The purpose of this is to set an upper bound on memory usage
|
||||
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
|
||||
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,7 +950,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1765,7 +1765,11 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
@@ -1786,7 +1790,9 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
@@ -1841,7 +1847,9 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -2565,12 +2573,22 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -516,6 +516,8 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -847,8 +851,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -861,6 +865,8 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -617,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -351,7 +352,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -647,10 +648,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -673,29 +674,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -704,22 +710,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
|
||||
_flush_chunks(current_chunks)
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -8,6 +10,7 @@ import httpx
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
|
||||
from onyx.configs.app_configs import RERANK_COUNT
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(
|
||||
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self._index_name,
|
||||
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -44,6 +44,7 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
"Value must be either numerical or a string containing a wildcard",
|
||||
"File contains no valid workbook part",
|
||||
"Unable to read workbook: could not read stylesheet from None",
|
||||
"Colors must be aRGB hex values",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
@@ -85,14 +86,21 @@ class DocumentIndexingBatchAdapter:
|
||||
) as transaction:
|
||||
yield transaction
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> "DocumentChunkEnricher":
|
||||
"""Do all DB lookups once and return a per-chunk enricher."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -102,67 +110,30 @@ class DocumentIndexingBatchAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_access_info = get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
doc_id_to_document_set = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
return DocumentChunkEnricher(
|
||||
doc_id_to_access_info=get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks_with_embeddings:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
# Get ancestor hierarchy node IDs for each document
|
||||
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
)
|
||||
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
|
||||
user_file_id_to_raw_text={},
|
||||
user_file_id_to_token_count={},
|
||||
),
|
||||
doc_id_to_document_set={
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
},
|
||||
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
),
|
||||
id_to_boost_map=context.id_to_boost_map,
|
||||
doc_id_to_previous_chunk_cnt={
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
},
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def _get_ancestor_ids_for_documents(
|
||||
@@ -203,7 +174,7 @@ class DocumentIndexingBatchAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None:
|
||||
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
@@ -227,7 +198,7 @@ class DocumentIndexingBatchAdapter:
|
||||
|
||||
update_docs_chunk_count__no_commit(
|
||||
document_ids=updatable_ids,
|
||||
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
|
||||
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@@ -249,3 +220,52 @@ class DocumentIndexingBatchAdapter:
|
||||
)
|
||||
|
||||
self.db_session.commit()
|
||||
|
||||
|
||||
class DocumentChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_id_to_access_info: dict[str, DocumentAccess],
|
||||
doc_id_to_document_set: dict[str, list[str]],
|
||||
doc_id_to_ancestor_ids: dict[str, list[int]],
|
||||
id_to_boost_map: dict[str, int],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._doc_id_to_access_info = doc_id_to_access_info
|
||||
self._doc_id_to_document_set = doc_id_to_document_set
|
||||
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
|
||||
self._id_to_boost_map = id_to_boost_map
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._doc_id_to_access_info.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(
|
||||
self._doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
self._id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in self._id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
@@ -24,11 +27,13 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -101,13 +106,20 @@ class UserFileIndexingAdapter:
|
||||
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
|
||||
)
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> UserFileChunkEnricher:
|
||||
"""Do all DB lookups and pre-compute file metadata from chunks."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
|
||||
content_by_file: dict[str, list[str]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
content_by_file[chunk.source_document.id].append(chunk.content)
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -117,7 +129,6 @@ class UserFileIndexingAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -138,17 +149,6 @@ class UserFileIndexingAdapter:
|
||||
)
|
||||
}
|
||||
|
||||
user_file_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
user_file_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
)
|
||||
for user_file_id in updatable_ids
|
||||
}
|
||||
|
||||
# Initialize tokenizer used for token count calculation
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
@@ -163,46 +163,30 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_raw_text: dict[str, str] = {}
|
||||
user_file_id_to_token_count: dict[str, int | None] = {}
|
||||
for user_file_id in updatable_ids:
|
||||
user_file_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
if user_file_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
contents = content_by_file.get(user_file_id)
|
||||
if contents:
|
||||
combined_content = " ".join(contents)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
user_file_id_to_raw_text[str(user_file_id)] = ""
|
||||
user_file_id_to_token_count[str(user_file_id)] = None
|
||||
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(),
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
return UserFileChunkEnricher(
|
||||
user_file_id_to_access=user_file_id_to_access,
|
||||
user_file_id_to_project_ids=user_file_id_to_project_ids,
|
||||
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
|
||||
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
user_file_id_to_raw_text=user_file_id_to_raw_text,
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def _notify_assistant_owners_if_files_ready(
|
||||
@@ -246,8 +230,9 @@ class UserFileIndexingAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
|
||||
filtered_documents: list[Document], # noqa: ARG002
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None:
|
||||
assert isinstance(enrichment, UserFileChunkEnricher)
|
||||
user_file_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
user_files = (
|
||||
@@ -263,8 +248,10 @@ class UserFileIndexingAdapter:
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
|
||||
user_file.token_count = result.user_file_id_to_token_count[
|
||||
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
|
||||
str(user_file.id), 0
|
||||
)
|
||||
user_file.token_count = enrichment.user_file_id_to_token_count[
|
||||
str(user_file.id)
|
||||
]
|
||||
|
||||
@@ -276,8 +263,54 @@ class UserFileIndexingAdapter:
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
# NOTE: this creates its own session to avoid committing the overall
|
||||
# transaction.
|
||||
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
|
||||
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
|
||||
store_user_file_plaintext(
|
||||
user_file_id=UUID(user_file_id),
|
||||
plaintext_content=raw_text,
|
||||
)
|
||||
|
||||
|
||||
class UserFileChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_file_id_to_access: dict[str, DocumentAccess],
|
||||
user_file_id_to_project_ids: dict[str, list[int]],
|
||||
user_file_id_to_persona_ids: dict[str, list[int]],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
user_file_id_to_raw_text: dict[str, str],
|
||||
user_file_id_to_token_count: dict[str, int | None],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._user_file_id_to_access = user_file_id_to_access
|
||||
self._user_file_id_to_project_ids = user_file_id_to_project_ids
|
||||
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
self.user_file_id_to_raw_text = user_file_id_to_raw_text
|
||||
self.user_file_id_to_token_count = user_file_id_to_token_count
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._user_file_id_to_access.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(),
|
||||
user_project=self._user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=self._user_file_id_to_persona_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
)
|
||||
|
||||
89
backend/onyx/indexing/chunk_batch_store.py
Normal file
89
backend/onyx/indexing/chunk_batch_store.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
class ChunkBatchStore:
|
||||
"""Manages serialization of embedded chunks to a temporary directory.
|
||||
|
||||
Owns the temp directory lifetime and provides save/load/stream/scrub
|
||||
operations.
|
||||
|
||||
Use as a context manager to ensure cleanup::
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
store.save(chunks, batch_idx=0)
|
||||
for chunk in store.stream():
|
||||
...
|
||||
"""
|
||||
|
||||
_EXT = ".pkl"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tmpdir: Path | None = None
|
||||
|
||||
# -- context manager -----------------------------------------------------
|
||||
|
||||
def __enter__(self) -> "ChunkBatchStore":
|
||||
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc: object) -> None:
|
||||
if self._tmpdir is not None:
|
||||
shutil.rmtree(self._tmpdir, ignore_errors=True)
|
||||
self._tmpdir = None
|
||||
|
||||
@property
|
||||
def _dir(self) -> Path:
|
||||
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
|
||||
return self._tmpdir
|
||||
|
||||
# -- storage primitives --------------------------------------------------
|
||||
|
||||
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
|
||||
"""Serialize a batch of embedded chunks to disk."""
|
||||
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
|
||||
pickle.dump(chunks, f)
|
||||
|
||||
def _load(self, batch_file: Path) -> list[IndexChunk]:
|
||||
"""Deserialize a batch of embedded chunks from a file."""
|
||||
with open(batch_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _batch_files(self) -> list[Path]:
|
||||
"""Return batch files sorted by numeric index."""
|
||||
return sorted(
|
||||
self._dir.glob(f"batch_*{self._EXT}"),
|
||||
key=lambda p: int(p.stem.removeprefix("batch_")),
|
||||
)
|
||||
|
||||
# -- higher-level operations ---------------------------------------------
|
||||
|
||||
def stream(self) -> Iterator[IndexChunk]:
|
||||
"""Yield all chunks across all batch files.
|
||||
|
||||
Each call returns a fresh generator, so the data can be iterated
|
||||
multiple times (e.g. once per document index).
|
||||
"""
|
||||
for batch_file in self._batch_files():
|
||||
yield from self._load(batch_file)
|
||||
|
||||
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
|
||||
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
|
||||
|
||||
When a document fails embedding in batch N, earlier batches may
|
||||
already contain successfully embedded chunks for that document.
|
||||
This ensures the output is all-or-nothing per document.
|
||||
"""
|
||||
for batch_file in self._batch_files():
|
||||
batch_chunks = self._load(batch_file)
|
||||
cleaned = [
|
||||
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
|
||||
]
|
||||
if len(cleaned) != len(batch_chunks):
|
||||
with open(batch_file, "wb") as f:
|
||||
pickle.dump(cleaned, f)
|
||||
@@ -1,5 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):
|
||||
|
||||
failures: list[ConnectorFailure]
|
||||
|
||||
@classmethod
|
||||
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
|
||||
return cls(
|
||||
new_docs=0,
|
||||
total_docs=total_docs,
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
|
||||
|
||||
class ChunkEmbeddingResult(BaseModel):
|
||||
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
|
||||
connector_failures: list[ConnectorFailure]
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
|
||||
)
|
||||
|
||||
|
||||
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
|
||||
"""Extract document IDs from a list of connector failures."""
|
||||
return {f.failed_document.document_id for f in failures if f.failed_document}
|
||||
|
||||
|
||||
def _embed_chunks_to_store(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
tenant_id: str,
|
||||
request_id: str | None,
|
||||
store: ChunkBatchStore,
|
||||
) -> ChunkEmbeddingResult:
|
||||
"""Embed chunks in batches, spilling each batch to *store*.
|
||||
|
||||
If a document fails embedding in any batch, its chunks are excluded from
|
||||
all batches (including earlier ones already written) so that the output
|
||||
is all-or-nothing per document.
|
||||
"""
|
||||
successful_chunk_ids: list[tuple[int, str]] = []
|
||||
all_embedding_failures: list[ConnectorFailure] = []
|
||||
# Track failed doc IDs across all batches so that a failure in batch N
|
||||
# causes chunks for that doc to be skipped in batch N+1 and stripped
|
||||
# from earlier batches.
|
||||
all_failed_doc_ids: set[str] = set()
|
||||
|
||||
for batch_idx, chunk_batch in enumerate(
|
||||
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
# Skip chunks belonging to documents that failed in earlier batches.
|
||||
chunk_batch = [
|
||||
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
|
||||
]
|
||||
if not chunk_batch:
|
||||
continue
|
||||
|
||||
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
|
||||
|
||||
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
|
||||
chunks=chunk_batch,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
all_embedding_failures.extend(embedding_failures)
|
||||
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
|
||||
|
||||
# Only keep successfully embedded chunks for non-failed docs.
|
||||
chunks_with_embeddings = [
|
||||
c
|
||||
for c in chunks_with_embeddings
|
||||
if c.source_document.id not in all_failed_doc_ids
|
||||
]
|
||||
|
||||
successful_chunk_ids.extend(
|
||||
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
|
||||
)
|
||||
|
||||
store.save(chunks_with_embeddings, batch_idx)
|
||||
del chunks_with_embeddings
|
||||
|
||||
# Scrub earlier batches for docs that failed in later batches.
|
||||
if all_failed_doc_ids:
|
||||
store.scrub_failed_docs(all_failed_doc_ids)
|
||||
successful_chunk_ids = [
|
||||
(chunk_id, doc_id)
|
||||
for chunk_id, doc_id in successful_chunk_ids
|
||||
if doc_id not in all_failed_doc_ids
|
||||
]
|
||||
|
||||
return ChunkEmbeddingResult(
|
||||
successful_chunk_ids=successful_chunk_ids,
|
||||
connector_failures=all_embedding_failures,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def embed_and_stream(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
tenant_id: str,
|
||||
request_id: str | None,
|
||||
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
|
||||
"""Embed chunks to disk and yield a ``(result, store)`` pair.
|
||||
|
||||
The store owns the temp directory — files are cleaned up when the context
|
||||
manager exits.
|
||||
|
||||
Usage::
|
||||
|
||||
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
|
||||
for chunk in store.stream():
|
||||
...
|
||||
"""
|
||||
with ChunkBatchStore() as store:
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
store=store,
|
||||
)
|
||||
yield result, store
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
@@ -637,6 +762,29 @@ def add_contextual_summaries(
|
||||
return chunks
|
||||
|
||||
|
||||
def _verify_indexing_completeness(
|
||||
insertion_records: list[DocumentInsertionRecord],
|
||||
write_failures: list[ConnectorFailure],
|
||||
embedding_failed_doc_ids: set[str],
|
||||
updatable_ids: list[str],
|
||||
document_index_name: str,
|
||||
) -> None:
|
||||
"""Verify that every updatable document was either indexed or reported as failed."""
|
||||
all_returned_doc_ids = (
|
||||
{r.document_id for r in insertion_records}
|
||||
| {f.failed_document.document_id for f in write_failures if f.failed_document}
|
||||
| embedding_failed_doc_ids
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
f"This should never happen. "
|
||||
f"This occured for document index {document_index_name}"
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -672,12 +820,7 @@ def index_doc_batch(
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
context = adapter.prepare(filtered_documents, ignore_time_skip)
|
||||
if not context:
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
return IndexingPipelineResult.empty(len(filtered_documents))
|
||||
|
||||
# Convert documents to IndexingDocument objects with processed section
|
||||
# logger.debug("Processing image sections")
|
||||
@@ -716,117 +859,99 @@ def index_doc_batch(
|
||||
)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
if chunks
|
||||
else ([], [])
|
||||
)
|
||||
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
|
||||
embedding_result,
|
||||
chunk_store,
|
||||
):
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk_id,
|
||||
document_id=document_id,
|
||||
boost_score=1.0,
|
||||
)
|
||||
for chunk_id, document_id in embedding_result.successful_chunk_ids
|
||||
]
|
||||
|
||||
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk.chunk_id,
|
||||
document_id=chunk.source_document.id,
|
||||
boost_score=score,
|
||||
)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
# with Vespa can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
# to resolve this, an update of the last modified field at the end of this loop
|
||||
# always triggers a final metadata sync via the celery queue
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=chunks_with_embeddings,
|
||||
chunk_content_scores=chunk_content_scores,
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
embedding_failed_doc_ids = _get_failed_doc_ids(
|
||||
embedding_result.connector_failures
|
||||
)
|
||||
|
||||
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
|
||||
short_descriptor_log = str(short_descriptor_list)[:1024]
|
||||
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
|
||||
# Filter to only successfully embedded chunks so
|
||||
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
|
||||
embedded_chunks = [
|
||||
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
|
||||
]
|
||||
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
|
||||
for document_index in document_indices:
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
chunks=result.chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
# Acquires a lock on the documents so that no other process can modify
|
||||
# them. Not needed until here, since this is when the actual race
|
||||
# condition with vector db can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=tenant_id,
|
||||
chunks=embedded_chunks,
|
||||
)
|
||||
|
||||
all_returned_doc_ids: set[str] = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"This occured for document index {document_index.__class__.__name__}"
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
result=result,
|
||||
)
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
|
||||
None
|
||||
)
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
|
||||
None
|
||||
)
|
||||
|
||||
for document_index in document_indices:
|
||||
|
||||
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for chunk in chunk_store.stream():
|
||||
yield enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
insertion_records, write_failures = (
|
||||
write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
make_chunks=_enriched_stream,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
)
|
||||
|
||||
_verify_indexing_completeness(
|
||||
insertion_records=insertion_records,
|
||||
write_failures=write_failures,
|
||||
embedding_failed_doc_ids=embedding_failed_doc_ids,
|
||||
updatable_ids=updatable_ids,
|
||||
document_index_name=document_index.__class__.__name__,
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
enrichment=enricher,
|
||||
)
|
||||
|
||||
assert primary_doc_idx_insertion_records is not None
|
||||
assert primary_doc_idx_vector_db_write_failures is not None
|
||||
return IndexingPipelineResult(
|
||||
new_docs=len(
|
||||
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
|
||||
new_docs=sum(
|
||||
1 for r in primary_doc_idx_insertion_records if not r.already_existed
|
||||
),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(chunks_with_embeddings),
|
||||
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
|
||||
total_chunks=len(embedding_result.successful_chunk_ids),
|
||||
failures=primary_doc_idx_vector_db_write_failures
|
||||
+ embedding_result.connector_failures,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
|
||||
boost_score: float
|
||||
|
||||
|
||||
class BuildMetadataAwareChunksResult(BaseModel):
|
||||
chunks: list[DocMetadataAwareIndexChunk]
|
||||
class ChunkEnrichmentContext(Protocol):
|
||||
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
|
||||
and provides per-chunk enrichment."""
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
user_file_id_to_raw_text: dict[str, str]
|
||||
user_file_id_to_token_count: dict[str, int | None]
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk: ...
|
||||
|
||||
|
||||
class IndexingBatchAdapter(Protocol):
|
||||
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
|
||||
) -> Generator[TransactionalContext, None, None]:
|
||||
"""Provide a transaction/row-lock context for critical updates."""
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
) -> BuildMetadataAwareChunksResult: ...
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> ChunkEnrichmentContext:
|
||||
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
|
||||
|
||||
Precondition: ``chunks`` have already been through the embedding step
|
||||
(i.e. they are ``IndexChunk`` instances with populated embeddings,
|
||||
passed here as the base ``DocAwareChunk`` type).
|
||||
"""
|
||||
...
|
||||
|
||||
def post_index(
|
||||
self,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None: ...
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
|
||||
|
||||
def write_chunks_to_vector_db_with_backoff(
|
||||
document_index: DocumentIndex,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
|
||||
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
|
||||
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
|
||||
vector DB interface assumes that all chunks for a single document are present.
|
||||
vector DB interface assumes that all chunks for a single document are present. The
|
||||
chunks must also be in contiguous batches
|
||||
"""
|
||||
|
||||
# first try to write the chunks to the vector db
|
||||
try:
|
||||
return (
|
||||
list(
|
||||
document_index.index(
|
||||
chunks=chunks,
|
||||
chunks=make_chunks(),
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
),
|
||||
@@ -60,14 +63,23 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
# wait a couple seconds just to give the vector db a chance to recover
|
||||
time.sleep(2)
|
||||
|
||||
# try writing each doc one by one
|
||||
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_for_docs[chunk.source_document.id].append(chunk)
|
||||
|
||||
insertion_records: list[DocumentInsertionRecord] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
for doc_id, chunks_for_doc in chunks_for_docs.items():
|
||||
|
||||
def key(chunk: DocMetadataAwareIndexChunk) -> str:
|
||||
return chunk.source_document.id
|
||||
|
||||
seen_doc_ids: set[str] = set()
|
||||
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
|
||||
if doc_id in seen_doc_ids:
|
||||
raise RuntimeError(
|
||||
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
|
||||
)
|
||||
seen_doc_ids.add(doc_id)
|
||||
|
||||
first_chunk = next(chunks_for_doc)
|
||||
chunks_for_doc = chain([first_chunk], chunks_for_doc)
|
||||
|
||||
try:
|
||||
insertion_records.extend(
|
||||
document_index.index(
|
||||
@@ -87,9 +99,7 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
document_link=first_chunk.get_link(),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
|
||||
@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
if budget_tokens is not None:
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -175,6 +175,32 @@ def get_tokenizer(
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
# Max characters per encode() call.
|
||||
_ENCODE_CHUNK_SIZE = 500_000
|
||||
|
||||
|
||||
def count_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
token_limit: int | None = None,
|
||||
) -> int:
|
||||
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
|
||||
|
||||
If token_limit is provided and the text is large enough to require
|
||||
multiple chunks (> 500k chars), stops early once the count exceeds it.
|
||||
When early-exiting, the returned value exceeds token_limit but may be
|
||||
less than the true full token count.
|
||||
"""
|
||||
if len(text) <= _ENCODE_CHUNK_SIZE:
|
||||
return len(tokenizer.encode(text))
|
||||
total = 0
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
if token_limit is not None and total > token_limit:
|
||||
return total # Already over — skip remaining chunks
|
||||
return total
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -3844,9 +3844,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"version": "5.0.5",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
|
||||
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
@@ -4224,9 +4224,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
|
||||
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -5007,9 +5007,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "1.1.12",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
|
||||
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
|
||||
"version": "1.1.13",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
|
||||
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
||||
@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -122,9 +123,8 @@ def _validate_endpoint(
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
|
||||
timeout (operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
@@ -141,19 +141,11 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -9,20 +9,15 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- If token length exceeds the admin-configured threshold, reject file.
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
# Check global skip flag (works for both single-tenant and multi-tenant)
|
||||
if SKIP_USERFILE_THRESHOLD:
|
||||
skip_threshold = True
|
||||
logger.info("Skipping userfile threshold check (global setting)")
|
||||
# Check tenant-specific skip list (only applicable in multi-tenant)
|
||||
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
|
||||
try:
|
||||
current_tenant_id = get_current_tenant_id()
|
||||
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
if skip_threshold:
|
||||
logger.info(
|
||||
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
# Derive limits from admin-configurable settings.
|
||||
# For upload size: load_settings() resolves 0/None to a positive default.
|
||||
# For token threshold: 0 means "no limit" (converted to None below).
|
||||
settings = load_settings()
|
||||
max_upload_size_mb = (
|
||||
settings.user_file_max_upload_size_mb
|
||||
) # always positive after load_settings()
|
||||
max_upload_size_bytes = (
|
||||
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
|
||||
)
|
||||
token_threshold_k = settings.file_token_count_threshold_k
|
||||
token_threshold = (
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
# Size limit is a hard safety cap.
|
||||
if max_upload_size_bytes is not None and is_upload_too_large(
|
||||
upload, max_upload_size_bytes
|
||||
):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
token_count = len(tokenizer.encode(text_content))
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1524,6 +1524,7 @@ def get_bifrost_available_models(
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -463,3 +463,4 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,12 +3,8 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -17,10 +19,16 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Notification
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.models import UserSettings
|
||||
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def admin_put_settings(
|
||||
settings: Settings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb is not None
|
||||
and settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
|
||||
)
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
@@ -83,6 +100,16 @@ def fetch_settings(
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
),
|
||||
default_file_token_count_threshold_k=(
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,19 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -78,7 +85,12 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
user_file_max_upload_size_mb: int | None = Field(
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
@@ -108,3 +120,14 @@ class UserSettings(Settings):
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
# Factory defaults so the frontend can show a "restore default" button.
|
||||
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
default_file_token_count_threshold_k: int = Field(
|
||||
default_factory=lambda: (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
|
||||
# Resolve context-aware defaults for token threshold.
|
||||
# None = admin hasn't set a value yet → use context-aware default.
|
||||
# 0 = admin explicitly chose "no limit" → preserve as-is.
|
||||
if settings.file_token_count_threshold_k is None:
|
||||
settings.file_token_count_threshold_k = (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
# Upload size: 0 and None are treated as "unset" (not "no limit") →
|
||||
# fall back to min(configured default, hard ceiling).
|
||||
if not settings.user_file_max_upload_size_mb:
|
||||
settings.user_file_max_upload_size_mb = min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
)
|
||||
|
||||
# Clamp to env ceiling so stale KV values are capped even if the
|
||||
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
|
||||
# was already saved (api.py only guards new writes).
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
courlan==1.3.2
|
||||
# via trafilatura
|
||||
cryptography==46.0.5
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# authlib
|
||||
# google-auth
|
||||
@@ -449,7 +449,7 @@ kombu==5.5.4
|
||||
# via celery
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.11
|
||||
langchain-core==1.2.22
|
||||
# via onyx
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
@@ -735,7 +735,7 @@ pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
# via onyx
|
||||
pygments==2.19.2
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
# via
|
||||
|
||||
@@ -97,7 +97,7 @@ comm==0.2.3
|
||||
# via ipykernel
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cryptography==46.0.5
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.1
|
||||
onyx-devtools==0.7.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pygments==2.19.2
|
||||
pygments==2.20.0
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
|
||||
@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
|
||||
@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
|
||||
@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
|
||||
else None
|
||||
)
|
||||
|
||||
# Global flag to skip userfile threshold for all users/tenants
|
||||
SKIP_USERFILE_THRESHOLD = (
|
||||
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
|
||||
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
|
||||
)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
|
||||
[
|
||||
tenant.strip()
|
||||
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
|
||||
if tenant.strip()
|
||||
]
|
||||
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
# Predates any test workspace messages, so the result set should match
|
||||
# the "no start time" case while exercising the oldest= parameter.
|
||||
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
|
||||
def test_slim_documents_access__public_channel(
|
||||
slack_connector: SlackConnector,
|
||||
start_ts: float | None,
|
||||
) -> None:
|
||||
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
|
||||
if not slack_connector.client:
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=start_ts,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=None,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
"""prepare_enrichment must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
|
||||
@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 failed")
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_stop_button_calls_completion_for_all_models(self) -> None:
|
||||
"""llm_loop_completion_handle must be called for all models when the stop button fires.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
|
||||
server thread is never blocked. Worker threads detect drain_done.is_set()
|
||||
after run_llm_loop completes and self-persist the result via
|
||||
llm_loop_completion_handle using their own DB session.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
|
||||
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
|
||||
# so the self-completion path (drain_done.is_set() check) is always taken.
|
||||
disconnect_received = threading.Event()
|
||||
# Set by the llm_loop_completion_handle mock when called.
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until the main thread signals that gen.close() has been called. This
|
||||
ensures drain_done is set before we return so model_succeeded is checked
|
||||
against a set drain_done — no race condition.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
disconnect_received.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
# Unblock the worker now that drain_done has been set by gen.close().
|
||||
disconnect_received.set()
|
||||
|
||||
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
|
||||
# Wait here, inside the patch context, so that get_session_with_current_tenant
|
||||
# and llm_loop_completion_handle mocks are still active when the worker calls them.
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "worker must self-complete via drain_done within 5 seconds"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called once for the successful model"
|
||||
|
||||
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
|
||||
"""B1 regression: model finishes BEFORE GeneratorExit fires.
|
||||
|
||||
The worker exits _run_model with drain_done.is_set()=False and skips
|
||||
self-completion. When gen.close() fires afterward, the finally else-branch
|
||||
must detect model_succeeded=True and call llm_loop_completion_handle itself.
|
||||
|
||||
Contrast with test_http_disconnect_completion_via_generator_exit, which
|
||||
tests the opposite ordering (worker finishes AFTER disconnect).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_and_return_immediately(**kwargs: Any) -> None:
|
||||
# Emit one packet so the drain loop has something to yield, then return
|
||||
# immediately — no blocking. The worker will be done in microseconds.
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_and_return_immediately,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
|
||||
# Give the worker thread time to finish completely (emit + return +
|
||||
# finally + self-completion check). It does almost no work, so 100 ms
|
||||
# is far more than enough while still keeping the test fast.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now close — worker is already done, so else-branch handles completion.
|
||||
gen.close()
|
||||
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "disconnect handler must call completion for a model that already finished"
|
||||
assert mock_handle.call_count == 1, "completion must be called exactly once"
|
||||
|
||||
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
|
||||
"""B2 regression: stop-button must NOT call completion for an errored model.
|
||||
|
||||
When model 0 raises an exception, its reserved ChatMessage must not be
|
||||
saved with 'stopped by user' — that message is wrong for a model that
|
||||
errored. llm_loop_completion_handle must only be called for non-errored
|
||||
models when the stop button fires.
|
||||
"""
|
||||
|
||||
def fail_model_0(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 errored")
|
||||
# Model 1: run forever (stop button fires before it finishes)
|
||||
time.sleep(10)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
# Return False immediately so the stop-button path fires while model 1
|
||||
# is still sleeping (model 0 has already errored by then).
|
||||
setup.check_is_connected = lambda: False
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Completion must NOT be called for model 0 (it errored).
|
||||
# It MAY be called for model 1 (still in-flight when stop fired).
|
||||
for call in mock_handle.call_args_list:
|
||||
assert (
|
||||
call.kwargs.get("llm") is not setup.llms[0]
|
||||
), "llm_loop_completion_handle must not be called for the errored model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
|
||||
"jira_base_url": jira_base_url,
|
||||
"project_key": project_key,
|
||||
}
|
||||
mock_cc_pair.connector.indexing_start = None
|
||||
|
||||
return mock_cc_pair
|
||||
|
||||
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
):
|
||||
print(doc)
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that generic_doc_sync derives indexing_start from cc_pair
|
||||
and forwards it to retrieve_all_slim_docs_perm_sync."""
|
||||
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
|
||||
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_none_when_no_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that indexing_start is None when the connector has no indexing_start set."""
|
||||
mock_jira_cc_pair.connector.indexing_start = None
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] is None
|
||||
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int,
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
sections=[TextSection(text="test", link="http://test.com")],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="test_doc",
|
||||
metadata={},
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links={0: "http://test.com"},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(full_embedding=[0.1] * 10, mini_chunk_embeddings=[]),
|
||||
title_embedding=[0.1] * 10,
|
||||
tenant_id="test_tenant",
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
ancestor_hierarchy_node_ids=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_index() -> tuple[OpenSearchDocumentIndex, MagicMock]:
|
||||
"""Creates an OpenSearchDocumentIndex with a mocked client.
|
||||
Returns the index and the mock for bulk_index_documents."""
|
||||
mock_client = MagicMock()
|
||||
mock_bulk = MagicMock()
|
||||
mock_client.bulk_index_documents = mock_bulk
|
||||
|
||||
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
|
||||
|
||||
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
|
||||
index._index_name = "test_index"
|
||||
index._client = mock_client
|
||||
index._tenant_state = tenant_state
|
||||
|
||||
return index, mock_bulk
|
||||
|
||||
|
||||
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=0,
|
||||
new_chunk_cnt=chunk_count,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_under_batch_limit_flushes_once() -> None:
|
||||
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
|
||||
index, mock_bulk = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 50
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert mock_bulk.call_count == 1
|
||||
batch_arg = mock_bulk.call_args_list[0]
|
||||
assert len(batch_arg.kwargs["documents"]) == num_chunks
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
|
||||
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
|
||||
index, mock_bulk = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
|
||||
assert mock_bulk.call_count == 3
|
||||
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
|
||||
assert batch_sizes == [100, 100, 50]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_exactly_at_batch_limit() -> None:
|
||||
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
|
||||
(the flush happens on the next chunk, not at the boundary)."""
|
||||
index, mock_bulk = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 100
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
|
||||
# so final flush handles all 100
|
||||
# Actually: the elif fires when len(current_chunks) >= 100, which happens
|
||||
# when current_chunks has 100 items and the 101st chunk arrives.
|
||||
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
|
||||
# No 101st chunk arrives, so the final flush handles all 100.
|
||||
assert mock_bulk.call_count == 1
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_one_over_batch_limit() -> None:
|
||||
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
|
||||
the 101st is flushed at the end."""
|
||||
index, mock_bulk = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 101
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert mock_bulk.call_count == 2
|
||||
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
|
||||
assert batch_sizes == [100, 1]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
|
||||
"""Multiple documents each under the batch limit should flush once per document."""
|
||||
index, mock_bulk = _make_index()
|
||||
chunks = []
|
||||
for doc_idx in range(3):
|
||||
doc_id = f"doc_{doc_idx}"
|
||||
for chunk_idx in range(50):
|
||||
chunks.append(_make_chunk(doc_id, chunk_idx))
|
||||
|
||||
metadata = IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
|
||||
for i in range(3)
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 3 documents = 3 flushes (one per doc boundary + final)
|
||||
assert mock_bulk.call_count == 3
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_delete_called_once_per_document() -> None:
|
||||
"""Even with multiple flushes for a single document, delete should only be
|
||||
called once per document."""
|
||||
index, _mock_bulk = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0) as mock_delete:
|
||||
index.index(chunks, metadata)
|
||||
|
||||
mock_delete.assert_called_once_with(doc_id, None)
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Unit tests for VespaDocumentIndex.index().
|
||||
|
||||
These tests mock all external I/O (HTTP calls, thread pools) and verify
|
||||
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
|
||||
construction.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stub_enrich(
|
||||
doc_id: str,
|
||||
old_chunk_cnt: int,
|
||||
) -> EnrichedDocumentIndexingInfo:
|
||||
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
|
||||
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
|
||||
return EnrichedDocumentIndexingInfo(
|
||||
doc_id=doc_id,
|
||||
chunk_start_index=0,
|
||||
old_version=False,
|
||||
chunk_end_index=old_chunk_cnt,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
|
||||
3,
|
||||
)
|
||||
def test_index_respects_batch_size(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
|
||||
multiple times with correctly sized batches."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
|
||||
assert mock_batch_index.call_count == 3
|
||||
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
|
||||
assert batch_sizes == [3, 3, 1]
|
||||
|
||||
# Verify all chunks are accounted for and in order
|
||||
all_indexed = [
|
||||
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
|
||||
]
|
||||
assert len(all_indexed) == 7
|
||||
assert [c.chunk_id for c in all_indexed] == list(range(7))
|
||||
391
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
391
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""Unit tests for _embed_chunks_to_store.
|
||||
|
||||
Tests cover:
|
||||
- Single batch, no failures
|
||||
- Multiple batches, no failures
|
||||
- Failure in a single batch
|
||||
- Cross-batch document failure scrubbing
|
||||
- Later batches skip already-failed docs
|
||||
- Empty input
|
||||
- All chunks fail
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_doc(doc_id: str) -> Document:
|
||||
return Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[TextSection(text="test", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=_make_doc(doc_id),
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
|
||||
"""Create an IndexChunk (a DocAwareChunk with embeddings)."""
|
||||
return IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=_make_doc(doc_id),
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_failure(doc_id: str) -> ConnectorFailure:
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(document_id=doc_id, document_link=None),
|
||||
failure_message="embedding failed",
|
||||
exception=RuntimeError("embedding failed"),
|
||||
)
|
||||
|
||||
|
||||
def _mock_embed_success(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
"""Simulate successful embedding of all chunks."""
|
||||
return (
|
||||
[_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
|
||||
[],
|
||||
)
|
||||
|
||||
|
||||
def _mock_embed_fail_doc(
|
||||
fail_doc_id: str,
|
||||
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
|
||||
"""Return an embed mock that fails all chunks for a specific doc."""
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
successes = [
|
||||
_make_index_chunk(c.source_document.id, c.chunk_id)
|
||||
for c in chunks
|
||||
if c.source_document.id != fail_doc_id
|
||||
]
|
||||
failures = (
|
||||
[_make_failure(fail_doc_id)]
|
||||
if any(c.source_document.id == fail_doc_id for c in chunks)
|
||||
else []
|
||||
)
|
||||
return successes, failures
|
||||
|
||||
return _embed
|
||||
|
||||
|
||||
class TestEmbedChunksInBatches:
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
|
||||
"""All chunks fit in one batch and embed successfully."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", i) for i in range(3)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 3
|
||||
assert len(result.connector_failures) == 0
|
||||
|
||||
# Verify stored contents
|
||||
assert len(store._batch_files()) == 1
|
||||
stored = list(store.stream())
|
||||
assert len(stored) == 3
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
|
||||
"""Chunks are split across multiple batches, all succeed."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", i) for i in range(7)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 7
|
||||
assert len(result.connector_failures) == 0
|
||||
assert len(store._batch_files()) == 3 # 3 + 3 + 1
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
|
||||
"""One doc fails embedding, its chunks are excluded from results."""
|
||||
mock_embed.side_effect = _mock_embed_fail_doc("doc2")
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("doc1", 0),
|
||||
_make_chunk("doc2", 1),
|
||||
_make_chunk("doc1", 2),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.connector_failures) == 1
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "doc2" not in successful_doc_ids
|
||||
assert "doc1" in successful_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_cross_batch_failure_scrubs_earlier_batch(
|
||||
self, mock_embed: MagicMock
|
||||
) -> None:
|
||||
"""Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
|
||||
batch 1. Its chunks should be scrubbed from batch 0's batch file."""
|
||||
call_count = 0
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _mock_embed_success(chunks)
|
||||
else:
|
||||
return _mock_embed_fail_doc("docA")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("docA", 0),
|
||||
_make_chunk("docA", 1),
|
||||
_make_chunk("docA", 2),
|
||||
_make_chunk("docA", 3),
|
||||
_make_chunk("docB", 0),
|
||||
_make_chunk("docB", 1),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# docA should be fully excluded from results
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "docA" not in successful_doc_ids
|
||||
assert "docB" in successful_doc_ids
|
||||
assert len(result.connector_failures) == 1
|
||||
|
||||
# Verify batch 0 was scrubbed of docA chunks
|
||||
all_stored = list(store.stream())
|
||||
stored_doc_ids = {c.source_document.id for c in all_stored}
|
||||
assert "docA" not in stored_doc_ids
|
||||
assert "docB" in stored_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
|
||||
"""If docA fails in batch 0, its chunks in batch 1 are skipped
|
||||
entirely (never sent to the embedder)."""
|
||||
embedded_doc_ids: list[str] = []
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
for c in chunks:
|
||||
embedded_doc_ids.append(c.source_document.id)
|
||||
return _mock_embed_fail_doc("docA")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("docA", 0),
|
||||
_make_chunk("docA", 1),
|
||||
_make_chunk("docA", 2),
|
||||
_make_chunk("docA", 3),
|
||||
_make_chunk("docB", 0),
|
||||
_make_chunk("docB", 1),
|
||||
]
|
||||
_embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# docA should only appear in batch 0, not batch 1
|
||||
batch_1_doc_ids = embedded_doc_ids[3:]
|
||||
assert "docA" not in batch_1_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
|
||||
self, mock_embed: MagicMock
|
||||
) -> None:
|
||||
"""doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
|
||||
doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
|
||||
should still be embedded successfully."""
|
||||
embedded_chunks: list[list[str]] = []
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
embedded_chunks.append([c.source_document.id for c in chunks])
|
||||
return _mock_embed_fail_doc("doc1")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("doc1", 0),
|
||||
_make_chunk("doc1", 1),
|
||||
_make_chunk("doc1", 2),
|
||||
_make_chunk("doc1", 3),
|
||||
_make_chunk("doc2", 0),
|
||||
_make_chunk("doc2", 1),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# doc1 should be fully excluded, doc2 fully included
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "doc1" not in successful_doc_ids
|
||||
assert "doc2" in successful_doc_ids
|
||||
assert len(result.successful_chunk_ids) == 2 # doc2's 2 chunks
|
||||
|
||||
# Batch 1 should only contain doc2 (doc1 was filtered before embedding)
|
||||
assert len(embedded_chunks) == 2
|
||||
assert "doc1" not in embedded_chunks[1]
|
||||
assert embedded_chunks[1] == ["doc2", "doc2"]
|
||||
|
||||
# Verify on-disk state has no doc1 chunks
|
||||
all_stored = list(store.stream())
|
||||
assert all(c.source_document.id == "doc2" for c in all_stored)
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
def test_empty_input(self, mock_embed: MagicMock) -> None:
|
||||
"""Empty chunk list produces empty results."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=[],
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 0
|
||||
assert len(result.connector_failures) == 0
|
||||
mock_embed.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
|
||||
"""When all documents fail, results have no successful chunks."""
|
||||
|
||||
def _fail_all(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
doc_ids = {c.source_document.id for c in chunks}
|
||||
return [], [_make_failure(doc_id) for doc_id in doc_ids]
|
||||
|
||||
mock_embed.side_effect = _fail_all
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 0
|
||||
assert len(result.connector_failures) == 2
|
||||
@@ -116,7 +116,7 @@ def _run_adapter_build(
|
||||
project_ids_map: dict[str, list[int]],
|
||||
persona_ids_map: dict[str, list[int]],
|
||||
) -> list[DocMetadataAwareIndexChunk]:
|
||||
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
|
||||
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
|
||||
with all external dependencies mocked."""
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import (
|
||||
UserFileIndexingAdapter,
|
||||
@@ -155,18 +155,16 @@ def _run_adapter_build(
|
||||
side_effect=Exception("no LLM in tests"),
|
||||
),
|
||||
):
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id="test_tenant",
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id="test_tenant",
|
||||
chunks=[chunk],
|
||||
)
|
||||
|
||||
return result.chunks
|
||||
return [enricher.enrich_chunk(chunk, 1.0)]
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
|
||||
def test_prepare_enrichment_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
|
||||
fetched from the DB into each chunk's metadata."""
|
||||
file_id = str(uuid4())
|
||||
persona_ids = [5, 12]
|
||||
@@ -183,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
assert chunks[0].user_project == project_ids
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
|
||||
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
|
||||
"""When a file has no persona or project associations in the DB, the
|
||||
adapter should default to empty lists (not KeyError or None)."""
|
||||
file_id = str(uuid4())
|
||||
|
||||
@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import Function as LiteLLMFunction
|
||||
|
||||
import onyx.llm.models
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_true() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(name="get_weather", arguments="{}"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is True
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_user_only() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello")]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
|
||||
"""When thinking is enabled but assistant messages with tool_calls lack
|
||||
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
|
||||
BadRequestError about missing thinking blocks."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
onyx.llm.models.ToolMessage(
|
||||
content="22°C sunny",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" not in kwargs, (
|
||||
"thinking param should be dropped when thinking_blocks are missing "
|
||||
"from assistant messages with tool_calls"
|
||||
)
|
||||
|
||||
|
||||
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
|
||||
"""When thinking is enabled and there are no historical assistant messages
|
||||
with tool_calls, the thinking param should be preserved."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" in kwargs, (
|
||||
"thinking param should be preserved when no assistant messages "
|
||||
"with tool_calls exist in history"
|
||||
)
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.natural_language_processing import utils as nlp_utils
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
class _Tokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
|
||||
return Settings(
|
||||
user_file_max_upload_size_mb=upload_size_mb,
|
||||
file_token_count_threshold_k=token_threshold_k,
|
||||
)
|
||||
|
||||
|
||||
def _patch_common_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
upload_size_mb: int = 1,
|
||||
token_threshold_k: int = 100,
|
||||
) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
monkeypatch.setattr(
|
||||
utils,
|
||||
"load_settings",
|
||||
lambda: _make_settings(upload_size_mb, token_threshold_k),
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
|
||||
upload = _make_upload("edge.png", size=1048576)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
large = _make_upload("large.png", size=1048577)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
def test_categorize_uploaded_files_enforces_size_limit_always(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
upload = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
oversized_doc = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_python_file(
|
||||
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
"""A positive upload_size_mb is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
py_source = b'def hello():\n print("world")\n'
|
||||
monkeypatch.setattr(
|
||||
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
|
||||
)
|
||||
|
||||
upload = _make_upload("script.py", size=len(py_source), content=py_source)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable[0].filename == "script.py"
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_binary_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
|
||||
|
||||
binary_content = bytes(range(256)) * 4
|
||||
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
|
||||
upload = _make_upload("huge.png", size=1048577, content=b"x")
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "data.bin"
|
||||
assert "Unsupported file type" in result.rejected[0].reason
|
||||
|
||||
|
||||
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A positive token_threshold_k is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("big_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
|
||||
|
||||
def test_categorize_no_token_limit_when_threshold_k_is_zero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
|
||||
monkeypatch.setattr(
|
||||
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
|
||||
)
|
||||
|
||||
upload = _make_upload("huge_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 0
|
||||
assert len(result.acceptable) == 1
|
||||
|
||||
|
||||
def test_categorize_both_limits_enforced(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Both positive limits are enforced; file exceeding token limit is rejected."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("over_tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 5K token limit"
|
||||
|
||||
|
||||
def test_categorize_rejection_reason_contains_dynamic_values(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
|
||||
|
||||
# File within size limit but over token limit
|
||||
upload = _make_upload("tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert result.rejected[0].reason == "Exceeds 7K token limit"
|
||||
|
||||
# File over size limit
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
|
||||
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
|
||||
|
||||
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
|
||||
|
||||
|
||||
# --- count_tokens tests ---
|
||||
|
||||
|
||||
def test_count_tokens_small_text() -> None:
|
||||
"""Small text should be encoded in a single call and return correct count."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "hello world"
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_chunked_matches_single_call() -> None:
|
||||
"""Chunked encoding should produce the same result as single-call for small text."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 1000
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
# _Tokenizer returns 1 token per char, so total should be 250
|
||||
assert count_tokens(text, tokenizer) == 250
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_exits_early(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set and exceeded, count_tokens should stop early."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
tokenizer = _CountingTokenizer()
|
||||
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer, token_limit=150)
|
||||
|
||||
assert result == 200 # 2 chunks × 100 tokens each
|
||||
assert encode_call_count == 2, "Should have stopped after 2 chunks"
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_not_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set but not exceeded, all chunks are encoded."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
result = count_tokens(text, tokenizer, token_limit=1000)
|
||||
assert result == 250
|
||||
|
||||
|
||||
def test_count_tokens_no_limit_encodes_all_chunks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Without token_limit, all chunks are encoded regardless of count."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer)
|
||||
assert result == 500
|
||||
|
||||
|
||||
# --- early exit via token_limit in categorize tests ---
|
||||
|
||||
|
||||
def test_categorize_early_exits_tokenization_for_large_text(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Large text files should be rejected via early-exit tokenization
|
||||
without encoding all chunks."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
|
||||
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
|
||||
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
|
||||
# Let's use a bigger text
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
large_text = "x" * 5000 # 5000 tokens, threshold 1000
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
|
||||
|
||||
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 1
|
||||
assert "token limit" in result.rejected[0].reason
|
||||
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
|
||||
assert (
|
||||
encode_call_count < 50
|
||||
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
|
||||
|
||||
|
||||
def test_categorize_text_under_token_limit_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Text files under the token threshold should be accepted with exact count."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
small_text = "x" * 500 # 500 tokens < 1000 threshold
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
|
||||
|
||||
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable_file_to_token_count["ok.txt"] == 500
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def __init__(self, data: dict | None = None) -> None:
|
||||
self._data = data
|
||||
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
if self._data is None:
|
||||
raise KvKeyNotFoundError()
|
||||
return self._data
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
@@ -20,13 +31,140 @@ class _FakeCache:
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
def test_load_settings_uses_model_defaults_when_no_stored_value(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When no settings are stored (vector DB enabled), load_settings() should
|
||||
resolve the default token threshold to 200."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled and no settings are stored, the token
|
||||
threshold should default to 10000 (10M tokens)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled but admin explicitly set a token threshold,
|
||||
that value should be preserved (not overridden by the 10000 default)."""
|
||||
stored = Settings(file_token_count_threshold_k=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 500
|
||||
|
||||
|
||||
def test_load_settings_preserves_zero_token_threshold(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 means 'no limit' and should be preserved."""
|
||||
stored = Settings(file_token_count_threshold_k=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 0
|
||||
|
||||
|
||||
def test_load_settings_resolves_zero_upload_size_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
|
||||
|
||||
def test_load_settings_clamps_upload_size_to_env_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be clamped to the env-configured maximum."""
|
||||
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 250
|
||||
|
||||
|
||||
def test_load_settings_preserves_upload_size_within_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be preserved unchanged."""
|
||||
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 150
|
||||
|
||||
|
||||
def test_load_settings_zero_upload_size_resolves_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default,
|
||||
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 100
|
||||
|
||||
|
||||
def test_load_settings_default_clamped_to_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
the effective default should be min(DEFAULT, MAX)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 50
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,96 +1,22 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
"""Tests for indexing pipeline setup."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement == placement
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user