Compare commits

..

5 Commits

Author SHA1 Message Date
Nik
cf403d9c89 feat(chat): carousel peek animation and interface cleanup for multi-model UI
- Extract MultiModelResponse interface to message/interfaces.ts (co-locate
  per CLAUDE.md convention; update multi-model-preview/page.tsx import)
- Carousel/fade in selection mode: non-preferred panels animate from adjacent
  to preferred → peek position (64px) at container edges, with mask-image
  gradient fade so they dissolve naturally at the viewport boundary
- Entry animation: panels slide out via CSS cubic-bezier transition triggered
  by requestAnimationFrame after mount, preventing initial-position flicker
- Fix test import: useMultiModelChat.test.tsx now imports renderHook+act from
  @tests/setup/test-utils (project wrapper) instead of @testing-library/react
- Revert FrostedDiv.tsx: remove unnecessary wrapperClassName prop (unused by
  any multi-model component)
- Add multi-model-preview dev page (was untracked)
2026-03-25 23:47:17 -07:00
Nik
ce7e68f671 fix(chat): fix hover state, dedup MAX_MODELS, clean up ModelSelector
- Replace Hoverable.Root/Item overlay with hover:bg-background-tint-02
  directly on panel container — correct Figma hover state (full-panel
  tint, no badge)
- Export MAX_MODELS from ModelSelector and import it in useMultiModelChat
  to eliminate the duplicate constant
- Replace @/components/ui/accordion import with @radix-ui/react-accordion
  (removes legacy directory dep)
- Remove dead "Compare Model" button with no onClick handler
- Remove SvgColumn import that was only used by the dead button
- Remove wrapper div around hidden MultiModelPanel (panel already
  self-sizes to w-[220px])
- Tighten panel max-width from 720px to 640px to match chat column width
2026-03-25 23:19:12 -07:00
Nik
65c5d5d5d9 feat(chat): add multi-model UI components and hook 2026-03-25 21:33:03 -07:00
Nik
ebe558e04f feat(chat): add frontend types and API helpers for multi-model streaming 2026-03-25 20:41:13 -07:00
Nik
a49edf3e18 feat(chat): add multi-model parallel streaming backend 2026-03-25 20:17:38 -07:00
301 changed files with 5585 additions and 16698 deletions

View File

@@ -704,9 +704,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -789,9 +786,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -47,8 +47,7 @@ jobs:
done
- name: Publish Helm charts to gh-pages
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts

View File

@@ -35,7 +35,6 @@ jobs:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 5
steps:
- name: Checkout

View File

@@ -183,7 +183,6 @@ jobs:
- cherry-pick-to-latest-release
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout
@@ -233,7 +232,6 @@ jobs:
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout

View File

@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"

View File

@@ -284,7 +284,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
@@ -626,7 +626,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}

View File

@@ -31,7 +31,6 @@ jobs:
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
environment: ci-protected
timeout-minutes: 45
env:

View File

@@ -15,7 +15,6 @@ permissions:
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd

View File

@@ -13,6 +13,15 @@ jobs:
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
@@ -22,11 +31,9 @@ jobs:
enable-cache: false
version: "0.9.9"
- run: |
for goos in linux windows darwin; do
for goarch in amd64 arm64; do
GOOS="$goos" GOARCH="$goarch" uv build --wheel
done
done
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
working-directory: cli
- run: uv publish
working-directory: cli

View File

@@ -25,7 +25,6 @@ permissions:
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
@@ -55,7 +54,6 @@ jobs:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

View File

@@ -9,7 +9,6 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 45
permissions:
contents: read

View File

@@ -11,7 +11,6 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 45
steps:

View File

@@ -1,64 +0,0 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}

View File

@@ -1,57 +0,0 @@
[
{
"scope": [],
"path": "contributing_guides/best_practices.md",
"description": "Best practices for contributing to the codebase"
},
{
"scope": ["web/**"],
"path": "web/AGENTS.md",
"description": "Frontend coding standards for the web directory"
},
{
"scope": ["web/**"],
"path": "web/tests/README.md",
"description": "Frontend testing guide and conventions"
},
{
"scope": ["web/**"],
"path": "web/CLAUDE.md",
"description": "Single source of truth for frontend coding standards"
},
{
"scope": ["web/**"],
"path": "web/lib/opal/README.md",
"description": "Opal component library usage guide"
},
{
"scope": ["backend/**"],
"path": "backend/tests/README.md",
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
},
{
"scope": ["backend/onyx/connectors/**"],
"path": "backend/onyx/connectors/README.md",
"description": "Connector development guide covering design, interfaces, and required changes"
},
{
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]

View File

@@ -1,39 +0,0 @@
# Greptile Review Rules
## Type Annotations
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
## Best Practices
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
## TODOs
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
## Debugging Code
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
## Hardcoded Booleans
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
## Multi-tenant vs Single-tenant
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -122,7 +122,7 @@ repos:
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
hooks:
- id: golangci-lint
language_version: "1.26.1"
language_version: "1.26.0"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit

View File

@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> ```
****

View File

@@ -1,35 +0,0 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -28,7 +28,6 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -188,6 +187,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -227,7 +227,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_permission_sync_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
@@ -474,8 +473,6 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -29,7 +29,6 @@ from ee.onyx.external_permissions.sync_params import (
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
@@ -163,6 +162,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -221,7 +221,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)

View File

@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
@shared_task(
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
"""
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
idempotent, tenants already at head are a fast no-op. This ensures pool
tenants are always current so that signup doesn't hit schema mismatches
(e.g. missing columns added after the tenant was pre-provisioned).
"""
with get_session_with_shared_schema() as db_session:
pool_tenants = db_session.query(AvailableTenant).all()
tenant_ids = [t.tenant_id for t in pool_tenants]
if not tenant_ids:
return
task_logger.info(
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
)
for tenant_id in tenant_ids:
try:
run_alembic_migrations(tenant_id)
new_version = get_current_alembic_version(tenant_id)
with get_session_with_shared_schema() as db_session:
tenant = (
db_session.query(AvailableTenant)
.filter_by(tenant_id=tenant_id)
.first()
)
if tenant and tenant.alembic_version != new_version:
task_logger.info(
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
)
tenant.alembic_version = new_version
db_session.commit()
except Exception:
task_logger.exception(
f"Failed to migrate pool tenant {tenant_id}, skipping"
)
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -106,11 +105,9 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback,
start=indexing_start,
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
@@ -183,15 +180,9 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector=slack_connector,
slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -41,19 +40,10 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -66,7 +65,6 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -517,8 +515,7 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
sender: Any, **kwargs: Any # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -1,6 +1,5 @@
# These are helper objects for tracking the keys we need to write in redis
import json
import threading
from typing import Any
from typing import cast
@@ -8,59 +7,7 @@ from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
_broker_client: Redis | None = None
_broker_url: str | None = None
_broker_client_lock = threading.Lock()
def celery_get_broker_client(app: Celery) -> Redis:
"""Return a shared Redis client connected to the Celery broker DB.
Uses a module-level singleton so all tasks on a worker share one
connection instead of creating a new one per call. The client
connects directly to the broker Redis DB (parsed from the broker URL).
Thread-safe via lock — safe for use in Celery thread-pool workers.
Usage:
r_celery = celery_get_broker_client(self.app)
length = celery_get_queue_length(queue, r_celery)
"""
global _broker_client, _broker_url
with _broker_client_lock:
url = app.conf.broker_url
if _broker_client is not None and _broker_url == url:
try:
_broker_client.ping()
return _broker_client
except Exception:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
elif _broker_client is not None:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
_broker_url = url
_broker_client = Redis.from_url(
url,
decode_responses=False,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
retry_on_timeout=True,
)
return _broker_client
def celery_get_unacked_length(r: Redis) -> int:

View File

@@ -14,7 +14,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -133,6 +132,7 @@ def revoke_tasks_blocking_deletion(
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -149,7 +149,6 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
r_celery = celery_get_broker_client(self.app)
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -9,7 +9,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -138,7 +137,6 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -450,7 +449,7 @@ def check_indexing_completion(
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = celery_get_broker_client(task.app)
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,

View File

@@ -1,5 +1,6 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
@@ -18,7 +19,6 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -698,27 +698,31 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
# Collect queue metrics with broker connection
r_celery = celery_get_broker_client(self.app)
queue_metrics = _collect_queue_metrics(r_celery)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect remaining metrics (no broker connection needed)
# Collect and log each metric
with get_session_with_current_tenant() as db_session:
all_metrics: list[Metric] = queue_metrics
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
for metric in all_metrics:
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
@@ -886,7 +890,7 @@ def monitor_celery_queues_helper(
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = celery_get_broker_client(task.app)
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -1076,7 +1080,7 @@ def cloud_monitor_celery_pidbox(
num_deleted = 0
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")

View File

@@ -17,7 +17,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -204,6 +203,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -261,7 +261,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

View File

@@ -16,7 +16,6 @@ from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
@@ -106,7 +105,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery = celery_get_broker_client(celery_app)
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
@@ -239,7 +238,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = celery_get_broker_client(self.app)
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
@@ -592,7 +591,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(

View File

@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
"""
import io
import queue
import re
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import Token
from uuid import UUID
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
@@ -59,6 +62,8 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -86,16 +91,21 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import ModelResponseSlot
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.tools.constants import SEARCH_TOOL_ID
@@ -1069,6 +1079,583 @@ def handle_stream_message_objects(
logger.exception("Error in setting processing status")
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
# Sentinel placed on the merged queue when a model thread finishes.
_MODEL_DONE = object()
class _ModelIndexEmitter(Emitter):
"""Emitter that tags packets with model_index and forwards directly to a shared queue.
Unlike the standard Emitter (which accumulates in a local bus), this puts
packets into the shared merged_queue in real-time as they're emitted. This
enables true parallel streaming — packets from multiple models interleave
on the wire instead of arriving in bursts after each model completes.
"""
def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
super().__init__(queue.Queue()) # bus exists for compat, unused
self._model_idx = model_idx
self._merged_queue = merged_queue
def emit(self, packet: Packet) -> None:
tagged_placement = Placement(
turn_index=packet.placement.turn_index if packet.placement else 0,
tab_index=packet.placement.tab_index if packet.placement else 0,
sub_turn_index=(
packet.placement.sub_turn_index if packet.placement else None
),
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
self._merged_queue.put((self._model_idx, tagged_packet))
def run_multi_model_stream(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
llm_overrides: list[LLMOverride],
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
# TODO(ENG-3888): The setup logic below (session resolution through tool construction)
# is duplicated from handle_stream_message_objects. Extract into a shared
# _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
# both paths call the same setup code.
# https://linear.app/onyx-app/issue/ENG-3888
"""Run 2-3 LLMs in parallel and yield their packets tagged with model_index.
Resource management:
- Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
- The caller's db_session is used only for setup (before threads launch) and
completion callbacks (after threads finish)
- ThreadPoolExecutor is bounded to len(overrides) workers
- All threads are joined in the finally block regardless of success/failure
- Queue-based merging avoids busy-waiting
"""
n_models = len(llm_overrides)
if n_models < 2 or n_models > 3:
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
if new_msg_req.deep_research:
raise ValueError("Multi-model is not supported with deep research")
tenant_id = get_current_tenant_id()
cache: CacheBackend | None = None
chat_session: ChatSession | None = None
user_id = user.id
if user.is_anonymous:
llm_user_identifier = "anonymous_user"
else:
llm_user_identifier = user.email or str(user_id)
try:
# ── Session setup (same as single-model path) ──────────────────
if not new_msg_req.chat_session_id:
if not new_msg_req.chat_session_info:
raise RuntimeError(
"Must specify a chat session id or chat session info"
)
chat_session = create_chat_session_from_request(
chat_session_request=new_msg_req.chat_session_info,
user_id=user_id,
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
)
persona = chat_session.persona
message_text = new_msg_req.message
# ── Build N LLM instances and validate costs ───────────────────
llms: list[LLM] = []
model_display_names: list[str] = []
for override in llm_overrides:
llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=llm.config.api_key,
)
llms.append(llm)
model_display_names.append(_build_model_display_name(override))
# Use first LLM for token counting (context window is checked per-model
# but token counting is model-agnostic enough for setup purposes)
token_counter = get_llm_token_counter(llms[0])
verify_user_files(
user_files=new_msg_req.file_descriptors,
user_id=user_id,
db_session=db_session,
project_id=chat_session.project_id,
)
# ── Chat history chain (shared across all models) ──────────────
chat_history = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
parent_message = chat_history[-1] if chat_history else root_message
elif (
new_msg_req.parent_message_id is None
or new_msg_req.parent_message_id == root_message.id
):
parent_message = root_message
chat_history = []
else:
parent_message = None
for i in range(len(chat_history) - 1, -1, -1):
if chat_history[i].id == new_msg_req.parent_message_id:
parent_message = chat_history[i]
chat_history = chat_history[: i + 1]
break
if parent_message is None:
raise ValueError(
"The new message sent is not on the latest mainline of messages"
)
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
message=message_text,
token_count=token_counter(message_text),
message_type=MessageType.USER,
files=new_msg_req.file_descriptors,
db_session=db_session,
commit=True,
)
chat_history.append(user_message)
available_files = _collect_available_file_ids(
chat_history=chat_history,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
summary_message = find_summary_for_branch(db_session, chat_history)
summarized_file_metadata: dict[str, FileToolMetadata] = {}
if summary_message and summary_message.last_summarized_message_id:
cutoff_id = summary_message.last_summarized_message_id
for msg in chat_history:
if msg.id > cutoff_id or not msg.files:
continue
for fd in msg.files:
file_id = fd.get("id")
if not file_id:
continue
summarized_file_metadata[file_id] = FileToolMetadata(
file_id=file_id,
filename=fd.get("name") or "unknown",
approx_char_count=0,
)
chat_history = [m for m in chat_history if m.id > cutoff_id]
user_memory_context = get_memories(user, db_session)
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=max_reserved_system_prompt_tokens_str,
token_counter=token_counter,
files=new_msg_req.file_descriptors,
user_memory_context=prompt_memory_context,
)
context_user_files = resolve_context_user_files(
persona=persona,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
# Use the smallest context window across all models for safety
min_context_window = min(llm.config.max_input_tokens for llm in llms)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=min_context_window,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
None,
)
forced_tool_id = new_msg_req.forced_tool_id
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
forced_tool_id = None
files = load_all_chat_files(chat_history, db_session)
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
# ── Reserve N assistant message IDs ────────────────────────────
reserved_messages = reserve_multi_model_message_ids(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message_id=user_message.id,
model_display_names=model_display_names,
)
yield MultiModelMessageResponseIDInfo(
user_message_id=user_message.id,
responses=[
ModelResponseSlot(message_id=m.id, model_name=name)
for m, name in zip(reserved_messages, model_display_names)
],
)
has_file_reader_tool = any(
tool.in_code_tool_id == "file_reader" for tool in all_tools
)
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
context_image_files=extracted_context_files.image_files,
additional_context=new_msg_req.additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
simple_chat_history = chat_history_result.simple_messages
all_injected_file_metadata: dict[str, FileToolMetadata] = (
chat_history_result.all_injected_file_metadata
if has_file_reader_tool
else {}
)
if summarized_file_metadata:
for fid, meta in summarized_file_metadata.items():
all_injected_file_metadata.setdefault(fid, meta)
if summary_message is not None:
summary_simple = ChatMessageSimple(
message=summary_message.message,
token_count=summary_message.token_count,
message_type=MessageType.ASSISTANT,
)
simple_chat_history.insert(0, summary_simple)
# ── Stop signal and processing status ──────────────────────────
cache = get_cache_backend()
reset_cancel_status(chat_session.id, cache)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, cache)
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=True,
)
# Release the main session's read transaction before the long stream
db_session.commit()
# ── Parallel model execution ───────────────────────────────────
# Each model thread writes tagged packets to this shared queue.
# Sentinel _MODEL_DONE signals that a thread finished.
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
queue.Queue()
)
# Track per-model state containers for completion callbacks
state_containers: list[ChatStateContainer] = [
ChatStateContainer() for _ in range(n_models)
]
# Track which models completed successfully (for completion callbacks)
model_succeeded: list[bool] = [False] * n_models
user_identity = LLMUserIdentity(
user_id=llm_user_identifier,
session_id=str(chat_session.id),
)
def _run_model(model_idx: int) -> None:
"""Run a single model in a worker thread.
Uses _ModelIndexEmitter so packets flow directly to merged_queue
in real-time (not batched after completion). This enables true
parallel streaming where both models' tokens interleave on the wire.
DB access: tools may need a session during execution (e.g., search
tool). Each thread creates its own session via context manager.
"""
model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
sc = state_containers[model_idx]
model_llm = llms[model_idx]
try:
# Each model thread gets its own DB session for tool execution.
# The session is scoped to the thread and closed when done.
with get_session_with_current_tenant() as thread_db_session:
# Construct tools per-thread with thread-local DB session
thread_tool_dict = construct_tools(
persona=persona,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id_filter=search_params.project_id_filter,
persona_id_filter=search_params.persona_id_filter,
bypass_acl=False,
enable_slack_search=_should_enable_slack_search(
persona, new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session.id,
message_id=user_message.id,
additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=available_files.user_file_ids,
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=search_params.search_usage,
)
model_tools: list[Tool] = []
for tool_list in thread_tool_dict.values():
model_tools.extend(tool_list)
# Run the LLM loop — this blocks until the model finishes.
# Packets flow to merged_queue in real-time via the emitter.
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(simple_chat_history),
tools=model_tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True
except Exception as e:
merged_queue.put((model_idx, e))
finally:
merged_queue.put((model_idx, _MODEL_DONE))
# Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
executor = ThreadPoolExecutor(
max_workers=n_models,
thread_name_prefix="multi-model",
)
futures = []
try:
for i in range(n_models):
futures.append(executor.submit(_run_model, i))
# ── Main thread: merge and yield packets ───────────────────
models_remaining = n_models
while models_remaining > 0:
try:
model_idx, item = merged_queue.get(timeout=0.3)
except queue.Empty:
# Check cancellation during idle periods
if not check_is_connected():
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
return
continue
else:
if item is _MODEL_DONE:
models_remaining -= 1
continue
if isinstance(item, Exception):
# Yield error as a tagged StreamingError packet.
# Do NOT decrement models_remaining here — the finally block
# in _run_model always posts _MODEL_DONE, which is the sole
# completion signal. Decrementing here too would double-count
# and cause the loop to exit early, silently dropping the
# surviving models' responses.
error_msg = str(item)
stack_trace = "".join(
traceback.format_exception(
type(item), item, item.__traceback__
)
)
# Redact API keys from error messages
model_llm = llms[model_idx]
if (
model_llm.config.api_key
and len(model_llm.config.api_key) > 2
):
error_msg = error_msg.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="MODEL_ERROR",
is_retryable=True,
details={
"model": model_llm.config.model_name,
"provider": model_llm.config.model_provider,
"model_index": model_idx,
},
)
continue
if isinstance(item, Packet):
# Packet is already tagged with model_index by _ModelIndexEmitter
yield item
# ── Completion: save each successful model's response ──────
# Run completion callbacks on the main thread using the main
# session. This is safe because all worker threads have exited
# by this point (merged_queue fully drained).
for i in range(n_models):
if not model_succeeded[i]:
continue
try:
llm_loop_completion_handle(
state_container=state_containers[i],
is_connected=check_is_connected,
db_session=db_session,
assistant_message=reserved_messages[i],
llm=llms[i],
reserved_tokens=reserved_token_count,
)
except Exception:
logger.exception(
f"Failed completion for model {i} "
f"({model_display_names[i]})"
)
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="complete"),
)
finally:
# Don't block on shutdown — futures making live LLM API calls
# cannot be cancelled once started, so wait=True would block
# the generator (and the HTTP response) until all calls finish.
# wait=False lets threads complete in the background.
executor.shutdown(wait=False)
except ValueError as e:
logger.exception("Failed to process multi-model chat message.")
yield StreamingError(
error=str(e),
error_code="VALIDATION_ERROR",
is_retryable=True,
)
db_session.rollback()
return
except Exception as e:
logger.exception(f"Failed multi-model chat: {e}")
stack_trace = traceback.format_exc()
yield StreamingError(
error=str(e),
stack_trace=stack_trace,
error_code="MULTI_MODEL_ERROR",
is_retryable=True,
)
db_session.rollback()
finally:
try:
if cache is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=False,
)
except Exception:
logger.exception("Error clearing processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],

View File

@@ -44,31 +44,6 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -86,6 +61,17 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -805,10 +791,6 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -673,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -1,32 +0,0 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
canvas_client: CanvasApiClient,
course_id: int,
) -> ExternalAccess | None:
if not global_version.is_ee_version():
return None
ee_get_course_permissions = cast(
Callable[[CanvasApiClient, int], ExternalAccess | None],
fetch_versioned_implementation(
"onyx.external_permissions.canvas.access",
"get_course_permissions",
),
)
return ee_get_course_permissions(canvas_client, course_id)

View File

@@ -1,212 +0,0 @@
from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
logger = logging.getLogger(__name__)
# Requests timeout in seconds.
_CANVAS_CALL_TIMEOUT: int = 30
_CANVAS_API_VERSION: str = "/api/v1"
# Matches the "next" URL in a Canvas Link header, e.g.:
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
# Captures the URL inside the angle brackets.
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]
return OnyxErrorCode.BAD_GATEWAY
class CanvasApiClient:
def __init__(
self,
bearer_token: str,
canvas_base_url: str,
) -> None:
parsed_base = urlparse(canvas_base_url)
if not parsed_base.hostname:
raise ValueError("canvas_base_url must include a valid host")
if parsed_base.scheme != "https":
raise ValueError("canvas_base_url must use https")
self._bearer_token = bearer_token
self.base_url = (
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
+ _CANVAS_API_VERSION
)
# Hostname is already validated above; reuse parsed_base instead
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
self._expected_host: str = parsed_base.hostname
def get(
self,
endpoint: str = "",
params: dict[str, Any] | None = None,
full_url: str | None = None,
) -> tuple[Any, str | None]:
"""Make a GET request to the Canvas API.
Returns a tuple of (json_body, next_url).
next_url is parsed from the Link header and is None if there are no more pages.
If full_url is provided, it is used directly (for following pagination links).
Security note: full_url must only be set to values returned by
``_parse_next_link``, which validates the host against the configured
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
"""
# full_url is used when following pagination (Canvas returns the
# next-page URL in the Link header). For the first request we build
# the URL from the endpoint name instead.
url = full_url if full_url else self._build_url(endpoint)
headers = self._build_headers()
response = rl_requests.get(
url,
headers=headers,
params=params if not full_url else None,
timeout=_CANVAS_CALL_TIMEOUT,
)
try:
response_json = response.json()
except ValueError as e:
if response.status_code < 300:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=f"Invalid JSON in Canvas response: {e}",
)
logger.warning(
"Failed to parse JSON from Canvas error response (status=%d): %s",
response.status_code,
e,
)
response_json = {}
if response.status_code >= 400:
# Try to extract the most specific error message from the
# Canvas response body. Canvas uses three different shapes
# depending on the endpoint and error type:
default_error: str = response.reason or f"HTTP {response.status_code}"
error = default_error
if isinstance(response_json, dict):
# Shape 1: {"error": {"message": "Not authorized"}}
error_field = response_json.get("error")
if isinstance(error_field, dict):
response_error = error_field.get("message", "")
if response_error:
error = response_error
# Shape 2: {"error": "Invalid access token"}
elif isinstance(error_field, str):
error = error_field
# Shape 3: {"errors": [{"message": "..."}]}
# Used for validation errors. Only use as fallback if
# we didn't already find a more specific message above.
if error == default_error:
errors_list = response_json.get("errors")
if isinstance(errors_list, list) and errors_list:
first_error = errors_list[0]
if isinstance(first_error, dict):
msg = first_error.get("message", "")
if msg:
error = msg
raise OnyxError(
_error_code_for_status(response.status_code),
detail=error,
status_code_override=response.status_code,
)
next_url = self._parse_next_link(response.headers.get("Link", ""))
return response_json, next_url
def _parse_next_link(self, link_header: str) -> str | None:
"""Extract the 'next' URL from a Canvas Link header.
Only returns URLs whose host matches the configured Canvas base URL
to prevent leaking the bearer token to arbitrary hosts.
"""
expected_host = self._expected_host
for match in _NEXT_LINK_PATTERN.finditer(link_header):
url = match.group(1)
parsed_url = urlparse(url)
if parsed_url.hostname != expected_host:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination returned an unexpected host "
f"({parsed_url.hostname}); expected {expected_host}"
),
)
if parsed_url.scheme != "https":
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination link must use https, "
f"got {parsed_url.scheme!r}"
),
)
return url
return None
def _build_headers(self) -> dict[str, str]:
"""Return the Authorization header with the bearer token."""
return {"Authorization": f"Bearer {self._bearer_token}"}
def _build_url(self, endpoint: str) -> str:
"""Build a full Canvas API URL from an endpoint path.
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
Only called on a first request, endpoint must be set for first request.
Verify endpoint exists in case of future changes where endpoint might be optional.
Leading slashes are stripped to avoid double-slash in the result.
self.base_url is already normalized with no trailing slash.
"""
final_url = self.base_url
clean_endpoint = endpoint.lstrip("/")
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,458 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
"""Map Canvas API errors to connector framework exceptions."""
if e.status_code == 401:
raise CredentialExpiredError(
"Canvas API token is invalid or expired (HTTP 401)."
)
elif e.status_code == 403:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
)
else:
raise ConnectorValidationError(
f"Canvas API error (status={e.status_code}): {e}"
)
class CanvasCourse(BaseModel):
id: int
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
class CanvasPage(BaseModel):
page_id: int
url: str
title: str
body: str | None = None
created_at: str | None = None
updated_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
name: str
description: str | None = None
html_url: str
course_id: int
created_at: str | None = None
updated_at: str | None = None
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
title: str
message: str | None = None
html_url: str
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
Fields:
course_ids: Materialized list of course IDs to process.
current_course_index: Index into course_ids for current course.
stage: Which item type we're processing for the current course.
next_url: Pagination cursor within the current stage. None means
start from the first page; a URL means resume from that page.
Invariant:
If current_course_index is incremented, stage must be reset to
"pages" and next_url must be reset to None.
"""
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = "pages"
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
try:
self.canvas_client.get("courses", params={"per_page": "1"})
logger.info("Canvas connector settings validated successfully")
except OnyxError as e:
_handle_canvas_api_error(e)
except ConnectorMissingCredentialError:
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
# TODO(benwu408): implemented in PR4 (perm sync)
raise NotImplementedError

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,9 +950,7 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_query = self._construct_attachment_query(_get_page_id(page))
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -10,7 +10,6 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -240,53 +239,29 @@ def enhanced_search_ids(
)
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
raise e
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in raw_issues
for issue in response["issues"]
]

View File

@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -1765,11 +1765,7 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1790,9 +1786,7 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor,
start=start,
end=end,
site_descriptor=site_descriptor
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1847,9 +1841,7 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
site_pages = self._fetch_site_pages(site_descriptor)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2573,22 +2565,12 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
yield from self._fetch_slim_documents_from_sharepoint()
if __name__ == "__main__":

View File

@@ -516,8 +516,6 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -548,8 +546,6 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -851,8 +847,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -865,8 +861,6 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -617,6 +617,80 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Set the preferred assistant response for a multi-model user message.
Validates that the user message is a USER type and that the preferred
assistant message is a direct child of that user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -839,6 +913,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -3135,6 +3135,8 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
)
@@ -108,10 +119,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
"""Soft-delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
provider.deleted = True
db_session.flush()

View File

@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,12 +1,11 @@
import json
from collections.abc import Iterable
from collections import defaultdict
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -352,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -648,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -674,34 +673,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
]
onyx_document: Document = doc_chunks[0].source_document
onyx_document: Document = chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -710,43 +704,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
document_indexing_results.append(document_insertion_record)
return document_indexing_results

View File

@@ -6,7 +6,6 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,8 +1,6 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -10,7 +8,6 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -321,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -341,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
@@ -421,16 +409,8 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -439,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -44,7 +44,6 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]

View File

@@ -19,8 +19,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -86,21 +85,14 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> "DocumentChunkEnricher":
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -110,30 +102,67 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
)
def _get_ancestor_ids_for_documents(
@@ -174,7 +203,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -198,7 +227,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -220,52 +249,3 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -27,13 +24,11 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -106,20 +101,13 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
no_access = DocumentAccess.build(
user_emails=[],
@@ -129,6 +117,7 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -149,6 +138,17 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -163,30 +163,46 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -230,9 +246,8 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -248,10 +263,8 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
str(user_file.id)
]
@@ -263,54 +276,8 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -1,89 +0,0 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from onyx.indexing.models import IndexChunk
class ChunkBatchStore:
"""Manages serialization of embedded chunks to a temporary directory.
Owns the temp directory lifetime and provides save/load/stream/scrub
operations.
Use as a context manager to ensure cleanup::
with ChunkBatchStore() as store:
store.save(chunks, batch_idx=0)
for chunk in store.stream():
...
"""
_EXT = ".pkl"
def __init__(self) -> None:
self._tmpdir: Path | None = None
# -- context manager -----------------------------------------------------
def __enter__(self) -> "ChunkBatchStore":
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
return self
def __exit__(self, *_exc: object) -> None:
if self._tmpdir is not None:
shutil.rmtree(self._tmpdir, ignore_errors=True)
self._tmpdir = None
@property
def _dir(self) -> Path:
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
return self._tmpdir
# -- storage primitives --------------------------------------------------
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
"""Serialize a batch of embedded chunks to disk."""
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
pickle.dump(chunks, f)
def _load(self, batch_file: Path) -> list[IndexChunk]:
"""Deserialize a batch of embedded chunks from a file."""
with open(batch_file, "rb") as f:
return pickle.load(f)
def _batch_files(self) -> list[Path]:
"""Return batch files sorted by numeric index."""
return sorted(
self._dir.glob(f"batch_*{self._EXT}"),
key=lambda p: int(p.stem.removeprefix("batch_")),
)
# -- higher-level operations ---------------------------------------------
def stream(self) -> Iterator[IndexChunk]:
"""Yield all chunks across all batch files.
Each call returns a fresh generator, so the data can be iterated
multiple times (e.g. once per document index).
"""
for batch_file in self._batch_files():
yield from self._load(batch_file)
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
When a document fails embedding in batch N, earlier batches may
already contain successfully embedded chunks for that document.
This ensures the output is all-or-nothing per document.
"""
for batch_file in self._batch_files():
batch_chunks = self._load(batch_file)
cleaned = [
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
]
if len(cleaned) != len(batch_chunks):
with open(batch_file, "wb") as f:
pickle.dump(cleaned, f)

View File

@@ -1,8 +1,5 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from pydantic import BaseModel
@@ -12,7 +9,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
@@ -47,12 +43,10 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -69,7 +63,6 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -98,20 +91,6 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class ChunkEmbeddingResult(BaseModel):
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
connector_failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -160,110 +139,6 @@ def _upsert_documents_in_db(
)
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
"""Extract document IDs from a list of connector failures."""
return {f.failed_document.document_id for f in failures if f.failed_document}
def _embed_chunks_to_store(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
"""Embed chunks in batches, spilling each batch to *store*.
If a document fails embedding in any batch, its chunks are excluded from
all batches (including earlier ones already written) so that the output
is all-or-nothing per document.
"""
successful_chunk_ids: list[tuple[int, str]] = []
all_embedding_failures: list[ConnectorFailure] = []
# Track failed doc IDs across all batches so that a failure in batch N
# causes chunks for that doc to be skipped in batch N+1 and stripped
# from earlier batches.
all_failed_doc_ids: set[str] = set()
for batch_idx, chunk_batch in enumerate(
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
):
# Skip chunks belonging to documents that failed in earlier batches.
chunk_batch = [
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
]
if not chunk_batch:
continue
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
chunks=chunk_batch,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
all_embedding_failures.extend(embedding_failures)
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
# Only keep successfully embedded chunks for non-failed docs.
chunks_with_embeddings = [
c
for c in chunks_with_embeddings
if c.source_document.id not in all_failed_doc_ids
]
successful_chunk_ids.extend(
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
)
store.save(chunks_with_embeddings, batch_idx)
del chunks_with_embeddings
# Scrub earlier batches for docs that failed in later batches.
if all_failed_doc_ids:
store.scrub_failed_docs(all_failed_doc_ids)
successful_chunk_ids = [
(chunk_id, doc_id)
for chunk_id, doc_id in successful_chunk_ids
if doc_id not in all_failed_doc_ids
]
return ChunkEmbeddingResult(
successful_chunk_ids=successful_chunk_ids,
connector_failures=all_embedding_failures,
)
@contextmanager
def embed_and_stream(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
"""Embed chunks to disk and yield a ``(result, store)`` pair.
The store owns the temp directory — files are cleaned up when the context
manager exits.
Usage::
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
for chunk in store.stream():
...
"""
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
store=store,
)
yield result, store
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -762,29 +637,6 @@ def add_contextual_summaries(
return chunks
def _verify_indexing_completeness(
insertion_records: list[DocumentInsertionRecord],
write_failures: list[ConnectorFailure],
embedding_failed_doc_ids: set[str],
updatable_ids: list[str],
document_index_name: str,
) -> None:
"""Verify that every updatable document was either indexed or reported as failed."""
all_returned_doc_ids = (
{r.document_id for r in insertion_records}
| {f.failed_document.document_id for f in write_failures if f.failed_document}
| embedding_failed_doc_ids
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
f"This should never happen. "
f"This occured for document index {document_index_name}"
)
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -820,7 +672,12 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult.empty(len(filtered_documents))
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -859,99 +716,117 @@ def index_doc_batch(
)
logger.debug("Starting embedding")
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
embedding_result,
chunk_store,
):
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk_id,
document_id=document_id,
boost_score=1.0,
)
for chunk_id, document_id in embedding_result.successful_chunk_ids
]
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
if chunks
else ([], [])
)
embedding_failed_doc_ids = _get_failed_doc_ids(
embedding_result.connector_failures
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with adapter.lock_context(context.updatable_docs):
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
context=context,
)
# Filter to only successfully embedded chunks so
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
embedded_chunks = [
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
]
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# Acquires a lock on the documents so that no other process can modify
# them. Not needed until here, since this is when the actual race
# condition with vector db can occur.
with adapter.lock_context(context.updatable_docs):
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=embedded_chunks,
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
None
)
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
None
)
for document_index in document_indices:
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
for chunk in chunk_store.stream():
yield enricher.enrich_chunk(chunk, 1.0)
insertion_records, write_failures = (
write_chunks_to_vector_db_with_backoff(
document_index=document_index,
make_chunks=_enriched_stream,
index_batch_params=index_batch_params,
)
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
_verify_indexing_completeness(
insertion_records=insertion_records,
write_failures=write_failures,
embedding_failed_doc_ids=embedding_failed_doc_ids,
updatable_ids=updatable_ids,
document_index_name=document_index.__class__.__name__,
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=sum(
1 for r in primary_doc_idx_insertion_records if not r.already_existed
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
),
total_docs=len(filtered_documents),
total_chunks=len(embedding_result.successful_chunk_ids),
failures=primary_doc_idx_vector_db_write_failures
+ embedding_result.connector_failures,
total_chunks=len(chunks_with_embeddings),
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
)

View File

@@ -235,16 +235,12 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
class IndexingBatchAdapter(Protocol):
@@ -258,24 +254,18 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: "DocumentBatchPrepareContext",
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None: ...

View File

@@ -1,9 +1,6 @@
import time
from collections.abc import Callable
from collections.abc import Iterable
from collections import defaultdict
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -31,22 +28,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous batches
vector DB interface assumes that all chunks for a single document are present.
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=make_chunks(),
chunks=chunks,
index_batch_params=index_batch_params,
)
),
@@ -63,23 +60,14 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
seen_doc_ids: set[str] = set()
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
if doc_id in seen_doc_ids:
raise RuntimeError(
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
)
seen_doc_ids.add(doc_id)
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
for doc_id, chunks_for_doc in chunks_for_docs.items():
try:
insertion_records.extend(
document_index.index(
@@ -99,7 +87,9 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=first_chunk.get_link(),
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,

View File

@@ -25,7 +25,6 @@ class LlmProviderNames(str, Enum):
LM_STUDIO = "lm_studio"
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
def __str__(self) -> str:
"""Needed so things like:
@@ -45,7 +44,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
]
@@ -63,7 +61,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -115,7 +112,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
}
# Model family name mappings for display name generation

View File

@@ -185,21 +185,6 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -305,17 +290,6 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
model_kwargs["api_base"] = self._api_base
# This is needed for Ollama to do proper function calling
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
model_kwargs["api_base"] = api_base
@@ -427,20 +401,14 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
model = (
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
)
# Tool choice
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
@@ -481,20 +449,7 @@ class LitellmLLM(LLM):
reasoning_effort
)
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if budget_tokens is not None:
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
@@ -528,11 +483,10 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
if not (is_claude_model or is_ollama or is_mistral):
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is
# routed through Bifrost's OpenAI-compatible endpoint.
# so it must be conditionally included.
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
# See also, https://github.com/ollama/ollama/issues/11171
optional_kwargs["allowed_openai_params"] = ["tool_choice"]

View File

@@ -13,8 +13,6 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
@@ -50,7 +49,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
}

View File

@@ -439,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -175,32 +175,6 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000
def count_tokens(
text: str,
tokenizer: BaseTokenizer,
token_limit: int | None = None,
) -> int:
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
If token_limit is provided and the text is large enough to require
multiple chunks (> 500k chars), stops early once the count exceeds it.
When early-exiting, the returned value exceeds token_limit but may be
less than the true full token count.
"""
if len(text) <= _ENCODE_CHUNK_SIZE:
return len(tokenizer.encode(text))
total = 0
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
if token_limit is not None and total > token_limit:
return total # Already over — skip remaining chunks
return total
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -3844,9 +3844,9 @@
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.5",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
@@ -4224,9 +4224,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -5007,9 +5007,9 @@
}
},
"node_modules/brace-expansion": {
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
@@ -123,8 +122,9 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
@@ -141,11 +141,19 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: timeout for %s",
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -9,15 +9,20 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -156,8 +161,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -168,33 +173,36 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
)
)
continue
@@ -216,11 +224,11 @@ def categorize_uploaded_files(
)
continue
if token_threshold is not None and token_count > token_threshold:
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:
@@ -261,14 +269,12 @@ def categorize_uploaded_files(
)
continue
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:

View File

@@ -57,8 +57,6 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import BifrostFinalModelResponse
from onyx.server.manage.llm.models import BifrostModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
@@ -1424,26 +1422,11 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="LiteLLM proxy",
api_key=api_key,
)
def _get_openai_compatible_models_response(
url: str,
source_name: str,
api_key: str | None = None,
) -> dict:
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
if not api_key:
headers.pop("Authorization")
try:
response = httpx.get(url, headers=headers, timeout=10.0)
@@ -1453,125 +1436,20 @@ def _get_openai_compatible_models_response(
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Authentication failed: invalid or missing API key for {source_name}.",
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
f"Failed to fetch LiteLLM models: {e}",
)
except httpx.RequestError as e:
logger.warning(
"Failed to fetch models from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
f"Failed to fetch LiteLLM models: {e}",
)
except ValueError as e:
logger.warning(
"Received invalid model response from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
)
@admin_router.post("/bifrost/available-models")
def get_bifrost_available_models(
request: BifrostModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[BifrostFinalModelResponse]:
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
response_json = _get_bifrost_models_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Bifrost endpoint",
)
results: list[BifrostFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
BifrostFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse Bifrost model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Bifrost",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="Bifrost",
)
return sorted_results
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="Bifrost",
api_key=api_key,
)

View File

@@ -449,18 +449,3 @@ class LitellmModelDetails(BaseModel):
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")
# Bifrost dynamic models fetch
class BifrostModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class BifrostFinalModelResponse(BaseModel):
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -25,7 +25,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.BEDROCK,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
}
)
@@ -51,25 +50,6 @@ BEDROCK_VISION_MODELS = frozenset(
}
)
# Known Bifrost/OpenAI-compatible vision-capable model families where the
# source API does not expose this metadata directly.
BIFROST_VISION_MODEL_FAMILIES = frozenset(
{
"anthropic/claude-3",
"anthropic/claude-4",
"amazon/nova-pro",
"amazon/nova-lite",
"amazon/nova-premier",
"openai/gpt-4o",
"openai/gpt-4.1",
"google/gemini",
"meta-llama/llama-3.2",
"mistral/pixtral",
"qwen/qwen2.5-vl",
"qwen/qwen-vl",
}
)
def is_valid_bedrock_model(
model_id: str,
@@ -96,18 +76,11 @@ def is_valid_bedrock_model(
def infer_vision_support(model_id: str) -> bool:
"""Infer vision support from model ID when base model metadata unavailable.
Used for providers like Bedrock and Bifrost where vision support may
need to be inferred from vendor/model naming conventions.
Used for cross-region inference profiles when the base model isn't
available in the user's region.
"""
model_id_lower = model_id.lower()
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
return True
normalized_model_id = model_id_lower.replace(".", "/")
return any(
vision_model in normalized_model_id
for vision_model in BIFROST_VISION_MODEL_FAMILIES
)
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
def generate_bedrock_display_name(model_id: str) -> str:
@@ -349,7 +322,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
- Ollama: "llama3:70b""Meta"
- Ollama: "qwen2.5:7b""Alibaba"
"""
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
if provider == LlmProviderNames.OPENROUTER:
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
if "/" in model_name:
vendor_key = model_name.split("/")[0].lower()

View File

@@ -12,6 +12,7 @@ stale, which is fine for monitoring dashboards.
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -103,23 +104,25 @@ class _CachedCollector(Collector):
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape."""
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
depth = GaugeMetricFamily(
"onyx_queue_depth",
@@ -401,19 +404,17 @@ class RedisHealthCollector(_CachedCollector):
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",
@@ -448,128 +449,40 @@ class RedisHealthCollector(_CachedCollector):
return [memory_used, memory_peak, memory_frag, connected_clients]
class WorkerHeartbeatMonitor:
"""Monitors Celery worker health via the event stream.
Subscribes to ``worker-heartbeat``, ``worker-online``, and
``worker-offline`` events via a single persistent connection.
Runs in a daemon thread started once during worker setup.
"""
# Consider a worker down if no heartbeat received for this long.
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
def __init__(self, celery_app: Any) -> None:
self._app = celery_app
self._worker_last_seen: dict[str, float] = {}
self._lock = threading.Lock()
self._running = False
self._thread: threading.Thread | None = None
def start(self) -> None:
"""Start the background event listener thread.
Safe to call multiple times — only starts one thread.
"""
if self._thread is not None and self._thread.is_alive():
return
self._running = True
self._thread = threading.Thread(target=self._listen, daemon=True)
self._thread.start()
logger.info("WorkerHeartbeatMonitor started")
def stop(self) -> None:
self._running = False
def _listen(self) -> None:
"""Background loop: connect to event stream and process heartbeats."""
while self._running:
try:
with self._app.connection() as conn:
recv = self._app.events.Receiver(
conn,
handlers={
"worker-heartbeat": self._on_heartbeat,
"worker-online": self._on_heartbeat,
"worker-offline": self._on_offline,
},
)
recv.capture(
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
)
except Exception:
if self._running:
logger.debug(
"Heartbeat listener disconnected, reconnecting in 5s",
exc_info=True,
)
time.sleep(5.0)
else:
# capture() returned normally (timeout with no events); reconnect
if self._running:
logger.debug("Heartbeat capture timed out, reconnecting")
time.sleep(5.0)
def _on_heartbeat(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen[hostname] = time.monotonic()
def _on_offline(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen.pop(hostname, None)
def get_worker_status(self) -> dict[str, bool]:
"""Return {hostname: is_alive} for all known workers.
Thread-safe. Called by WorkerHealthCollector on each scrape.
Also prunes workers that have been dead longer than 2x the
heartbeat timeout to prevent unbounded growth.
"""
now = time.monotonic()
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
with self._lock:
# Prune workers that have been gone for 2x the timeout
stale = [
h
for h, ts in self._worker_last_seen.items()
if (now - ts) > prune_threshold
]
for h in stale:
del self._worker_last_seen[h]
result: dict[str, bool] = {}
for hostname, last_seen in self._worker_last_seen.items():
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
result[hostname] = alive
return result
class WorkerHealthCollector(_CachedCollector):
"""Collects Celery worker health from the heartbeat monitor.
"""Collects Celery worker count and process count via inspect ping.
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
to the Celery event stream via a single persistent connection.
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
command that takes a couple seconds to complete.
Maintains a set of known worker short-names so that when a worker
stops responding, we emit ``up=0`` instead of silently dropping the
metric (which would make ``absent()``-style alerts impossible).
"""
def __init__(self, cache_ttl: float = 30.0) -> None:
super().__init__(cache_ttl)
self._monitor: WorkerHeartbeatMonitor | None = None
# Remove a worker from _known_workers after this many consecutive
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
_MAX_CONSECUTIVE_MISSES = 10
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
"""Set the heartbeat monitor instance."""
self._monitor = monitor
def __init__(self, cache_ttl: float = 60.0) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
# worker short-name → consecutive miss count.
# Workers start at 0 and reset to 0 each time they respond.
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
self._known_workers: dict[str, int] = {}
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app instance for inspect commands."""
self._celery_app = app
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._monitor is None:
if self._celery_app is None:
return []
active_workers = GaugeMetricFamily(
"onyx_celery_active_worker_count",
"Number of active Celery workers with recent heartbeats",
"Number of active Celery workers responding to ping",
)
worker_up = GaugeMetricFamily(
"onyx_celery_worker_up",
@@ -578,15 +491,37 @@ class WorkerHealthCollector(_CachedCollector):
)
try:
status = self._monitor.get_worker_status()
alive_count = sum(1 for alive in status.values() if alive)
active_workers.add_metric([], alive_count)
inspector = self._celery_app.control.inspect(timeout=3.0)
ping_result = inspector.ping()
for hostname in sorted(status):
# Use short name (before @) for single-host deployments,
# full hostname when multiple hosts share a worker type.
label = hostname.split("@")[0]
worker_up.add_metric([label], 1 if status[hostname] else 0)
responding: set[str] = set()
if ping_result:
active_workers.add_metric([], len(ping_result))
for worker_name in ping_result:
# Strip hostname suffix for cleaner labels
short_name = worker_name.split("@")[0]
responding.add(short_name)
else:
active_workers.add_metric([], 0)
# Register newly-seen workers and reset miss count for
# workers that responded.
for short_name in responding:
self._known_workers[short_name] = 0
# Increment miss count for non-responding workers and evict
# those that have been missing too long.
stale = []
for short_name in list(self._known_workers):
if short_name not in responding:
self._known_workers[short_name] += 1
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
stale.append(short_name)
for short_name in stale:
del self._known_workers[short_name]
for short_name in sorted(self._known_workers):
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
except Exception:
logger.debug("Failed to collect worker health metrics", exc_info=True)

View File

@@ -3,21 +3,24 @@
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_celery_app() is called. Keeping them at
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
@@ -25,28 +28,75 @@ _attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a
celery_app: The Celery application instance. Used to obtain a fresh
broker Redis client on each scrape for queue depth metrics.
"""
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.
# Module-level singleton prevents duplicate threads on re-entry.
global _heartbeat_monitor
if _heartbeat_monitor is None:
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
_heartbeat_monitor.start()
_worker_health_collector.set_monitor(_heartbeat_monitor)
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
_worker_health_collector.set_celery_app(celery_app)
_attempt_collector.configure()
_connector_collector.configure()

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import run_multi_model_stream
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in run_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -9,9 +9,7 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -19,16 +17,10 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -49,15 +41,6 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -100,16 +83,6 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,19 +2,12 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -85,12 +78,7 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
user_file_max_upload_size_mb: int | None = None
# Connector settings
show_extra_connectors: bool | None = True
@@ -120,14 +108,3 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,19 +1,13 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -57,36 +51,9 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.6
cryptography==46.0.5
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.22
langchain-core==1.2.11
# via onyx
langdetect==1.0.9
# via unstructured
@@ -735,7 +735,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
# via onyx
pygments==2.20.0
pygments==2.19.2
# via rich
pyjwt==2.12.0
# via

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt
@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.2
onyx-devtools==0.7.1
# via onyx
openai==2.14.0
# via
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.20.0
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -191,6 +191,25 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,6 +1,4 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -19,10 +17,6 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -111,17 +105,15 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=start_ts,
start=0.0,
end=time.time(),
)
@@ -157,7 +149,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=None,
start=0.0,
end=time.time(),
)

View File

@@ -129,10 +129,6 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -88,22 +88,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
Also patches ``celery_get_broker_client`` so the mock app doesn't need
a real broker URL.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
class TestAdapterWritesBothMetadataFields:
"""prepare_enrichment must populate user_project AND personas."""
"""build_metadata_aware_chunks must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
@@ -153,13 +153,15 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -188,13 +190,15 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -225,13 +229,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -256,13 +261,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -294,11 +300,12 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -90,17 +90,8 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
task only.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -6,7 +6,6 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -22,7 +21,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -203,25 +201,3 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,7 +5,6 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -167,29 +166,3 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -1,87 +0,0 @@
"""Tests for celery_get_broker_client singleton."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.background.celery import celery_redis
@pytest.fixture(autouse=True)
def reset_singleton() -> Iterator[None]:
"""Reset the module-level singleton between tests."""
celery_redis._broker_client = None
celery_redis._broker_url = None
yield
celery_redis._broker_client = None
celery_redis._broker_url = None
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
app = MagicMock()
app.conf.broker_url = broker_url
return app
class TestCeleryGetBrokerClient:
@patch("onyx.background.celery.celery_redis.Redis")
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
result = celery_redis.celery_get_broker_client(app)
assert result is mock_client
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://localhost:6379/15"
assert call_args[1]["decode_responses"] is False
assert call_args[1]["socket_keepalive"] is True
assert call_args[1]["retry_on_timeout"] is True
@patch("onyx.background.celery.celery_redis.Redis")
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
client1 = celery_redis.celery_get_broker_client(app)
client2 = celery_redis.celery_get_broker_client(app)
assert client1 is client2
# from_url called only once
assert mock_redis_cls.from_url.call_count == 1
@patch("onyx.background.celery.celery_redis.Redis")
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
stale_client = MagicMock()
stale_client.ping.side_effect = ConnectionError("disconnected")
fresh_client = MagicMock()
fresh_client.ping.return_value = True
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
app = _make_mock_app()
# First call creates stale_client
client1 = celery_redis.celery_get_broker_client(app)
assert client1 is stale_client
# Second call: ping fails, creates fresh_client
client2 = celery_redis.celery_get_broker_client(app)
assert client2 is fresh_client
assert mock_redis_cls.from_url.call_count == 2
@patch("onyx.background.celery.celery_redis.Redis")
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
mock_redis_cls.from_url.return_value = MagicMock()
app = _make_mock_app("redis://custom-host:6380/3")
celery_redis.celery_get_broker_client(app)
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://custom-host:6380/3"

View File

@@ -0,0 +1,207 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in run_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
from onyx.chat.process_message import run_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = run_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
# ---------------------------------------------------------------------------
# run_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
try:
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
except Exception:
pass # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"

View File

@@ -1,876 +0,0 @@
"""Tests for Canvas connector — client, credentials, conversion."""
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.canvas.connector import CanvasConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
FAKE_BASE_URL = "https://myschool.instructure.com"
FAKE_TOKEN = "fake-canvas-token"
def _mock_course(
course_id: int = 1,
name: str = "Intro to CS",
course_code: str = "CS101",
) -> dict[str, Any]:
return {
"id": course_id,
"name": name,
"course_code": course_code,
"created_at": "2025-01-01T00:00:00Z",
"workflow_state": "available",
}
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
"""Build a connector with mocked credential validation."""
with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=base_url)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
return connector
def _mock_page(
page_id: int = 10,
title: str = "Syllabus",
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"page_id": page_id,
"url": "syllabus",
"title": title,
"body": "<p>Welcome to the course</p>",
"created_at": "2025-01-15T00:00:00Z",
"updated_at": updated_at,
}
def _mock_assignment(
assignment_id: int = 20,
name: str = "Homework 1",
course_id: int = 1,
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": assignment_id,
"name": name,
"description": "<p>Solve these problems</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
"course_id": course_id,
"created_at": "2025-01-20T00:00:00Z",
"updated_at": updated_at,
"due_at": "2025-02-01T23:59:00Z",
}
def _mock_announcement(
announcement_id: int = 30,
title: str = "Class Cancelled",
course_id: int = 1,
posted_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": announcement_id,
"title": title,
"message": "<p>No class today</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
"posted_at": posted_at,
}
def _mock_response(
status_code: int = 200,
json_data: Any = None,
link_header: str = "",
) -> MagicMock:
"""Create a mock HTTP response with status, json, and Link header."""
resp = MagicMock()
resp.status_code = status_code
resp.reason = "OK" if status_code < 300 else "Error"
resp.json.return_value = json_data if json_data is not None else []
resp.headers = {"Link": link_header}
return resp
# ---------------------------------------------------------------------------
# CanvasApiClient.__init__ tests
# ---------------------------------------------------------------------------
class TestCanvasApiClientInit:
def test_success(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
expected_host = "myschool.instructure.com"
assert client.base_url == expected_base_url
assert client._expected_host == expected_host
def test_normalizes_trailing_slash(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=f"{FAKE_BASE_URL}/",
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
assert client.base_url == expected_base_url
def test_normalizes_existing_api_v1(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=f"{FAKE_BASE_URL}/api/v1",
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
assert client.base_url == expected_base_url
def test_rejects_non_https_scheme(self) -> None:
with pytest.raises(ValueError, match="must use https"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="ftp://myschool.instructure.com",
)
def test_rejects_http(self) -> None:
with pytest.raises(ValueError, match="must use https"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="http://myschool.instructure.com",
)
def test_rejects_missing_host(self) -> None:
with pytest.raises(ValueError, match="must include a valid host"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="https://",
)
# ---------------------------------------------------------------------------
# CanvasApiClient._build_url tests
# ---------------------------------------------------------------------------
class TestBuildUrl:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
def test_appends_endpoint(self) -> None:
result = self.client._build_url("courses")
expected = f"{FAKE_BASE_URL}/api/v1/courses"
assert result == expected
def test_strips_leading_slash_from_endpoint(self) -> None:
result = self.client._build_url("/courses")
expected = f"{FAKE_BASE_URL}/api/v1/courses"
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient._build_headers tests
# ---------------------------------------------------------------------------
class TestBuildHeaders:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
def test_returns_bearer_auth(self) -> None:
result = self.client._build_headers()
expected = {"Authorization": f"Bearer {FAKE_TOKEN}"}
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.get tests
# ---------------------------------------------------------------------------
class TestGet:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_success_returns_json_and_next_url(self, mock_requests: MagicMock) -> None:
next_link = f"<{FAKE_BASE_URL}/api/v1/courses?page=2>; " 'rel="next"'
mock_requests.get.return_value = _mock_response(
json_data=[{"id": 1}], link_header=next_link
)
data, next_url = self.client.get("courses")
expected_data = [{"id": 1}]
expected_next = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
assert data == expected_data
assert next_url == expected_next
@patch("onyx.connectors.canvas.client.rl_requests")
def test_success_no_next_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[{"id": 1}])
data, next_url = self.client.get("courses")
assert data == [{"id": 1}]
assert next_url is None
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_error_status(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(403, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 403
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_404(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(404, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 404
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(429, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 429
@patch("onyx.connectors.canvas.client.rl_requests")
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
full = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
self.client.get(params={"per_page": "100"}, full_url=full)
_, kwargs = mock_requests.get.call_args
assert kwargs["params"] is None
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_dict(
self, mock_requests: MagicMock
) -> None:
"""Shape 1: {"error": {"message": "Not authorized"}}"""
mock_requests.get.return_value = _mock_response(
403, {"error": {"message": "Not authorized"}}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Not authorized"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_string(
self, mock_requests: MagicMock
) -> None:
"""Shape 2: {"error": "Invalid access token"}"""
mock_requests.get.return_value = _mock_response(
401, {"error": "Invalid access token"}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Invalid access token"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_errors_list(
self, mock_requests: MagicMock
) -> None:
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
mock_requests.get.return_value = _mock_response(
400, {"errors": [{"message": "Invalid query"}]}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Invalid query"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_dict_takes_priority_over_errors_list(
self, mock_requests: MagicMock
) -> None:
"""When both error shapes are present, error dict wins."""
mock_requests.get.return_value = _mock_response(
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Specific error"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_falls_back_to_reason_when_no_json_message(
self, mock_requests: MagicMock
) -> None:
"""Empty error body falls back to response.reason."""
mock_requests.get.return_value = _mock_response(500, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Error" # from _mock_response's reason for >= 300
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
"""Invalid JSON on a 2xx response raises OnyxError."""
resp = MagicMock()
resp.status_code = 200
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
with pytest.raises(OnyxError, match="Invalid JSON"):
self.client.get("courses")
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_error_falls_back_to_reason(
self, mock_requests: MagicMock
) -> None:
"""Invalid JSON on a 4xx response falls back to response.reason."""
resp = MagicMock()
resp.status_code = 500
resp.reason = "Internal Server Error"
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Internal Server Error"
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.paginate tests
# ---------------------------------------------------------------------------
class TestPaginate:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[{"id": 1}, {"id": 2}]
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 1
assert pages[0] == [{"id": 1}, {"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_two_pages(self, mock_requests: MagicMock) -> None:
next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
page2 = _mock_response(json_data=[{"id": 2}])
mock_requests.get.side_effect = [page1, page2]
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 2
assert pages[0] == [{"id": 1}]
assert pages[1] == [{"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert pages == []
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
# ---------------------------------------------------------------------------
class TestParseNextLink:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="https://canvas.example.com",
)
def test_found(self) -> None:
header = '<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
result = self.client._parse_next_link(header)
expected = "https://canvas.example.com/api/v1/courses?page=2"
assert result == expected
def test_not_found(self) -> None:
header = '<https://canvas.example.com/api/v1/courses?page=1>; rel="current"'
result = self.client._parse_next_link(header)
assert result is None
def test_empty(self) -> None:
result = self.client._parse_next_link("")
assert result is None
def test_multiple_rels(self) -> None:
header = (
'<https://canvas.example.com/api/v1/courses?page=1>; rel="current", '
'<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
)
result = self.client._parse_next_link(header)
expected = "https://canvas.example.com/api/v1/courses?page=2"
assert result == expected
def test_rejects_host_mismatch(self) -> None:
header = '<https://evil.example.com/api/v1/courses?page=2>; rel="next"'
with pytest.raises(OnyxError, match="unexpected host"):
self.client._parse_next_link(header)
def test_rejects_non_https_link(self) -> None:
header = '<http://canvas.example.com/api/v1/courses?page=2>; rel="next"'
with pytest.raises(OnyxError, match="must use https"):
self.client._parse_next_link(header)
# ---------------------------------------------------------------------------
# CanvasConnector — credential loading
# ---------------------------------------------------------------------------
class TestLoadCredentials:
def _assert_load_credentials_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert load_credentials raises expected_error for a given status."""
mock_requests.get.return_value = _mock_response(status_code, {})
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(expected_error):
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
assert result is None
assert connector._canvas_client is not None
def test_canvas_client_raises_without_credentials(self) -> None:
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(ConnectorMissingCredentialError):
_ = connector.canvas_client
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_insufficient_permissions(
self, mock_requests: MagicMock
) -> None:
self._assert_load_credentials_raises(
403, InsufficientPermissionsError, mock_requests
)
# ---------------------------------------------------------------------------
# CanvasConnector — URL normalization
# ---------------------------------------------------------------------------
class TestConnectorUrlNormalization:
def test_strips_api_v1_suffix(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_strips_trailing_slash(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_no_change_for_clean_url(self) -> None:
connector = _build_connector(base_url=FAKE_BASE_URL)
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
# ---------------------------------------------------------------------------
# CanvasConnector — document conversion
# ---------------------------------------------------------------------------
class TestDocumentConversion:
def setup_method(self) -> None:
self.connector = _build_connector()
def test_convert_page_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=10,
url="syllabus",
title="Syllabus",
body="<p>Welcome</p>",
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
expected_id = "canvas-page-1-10"
expected_metadata = {"course_id": "1", "type": "page"}
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Syllabus"
assert doc.metadata == expected_metadata
assert doc.sections[0].link is not None
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
assert doc.doc_updated_at == expected_updated_at
def test_convert_page_without_body(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=11,
url="empty-page",
title="Empty Page",
body=None,
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
section_text = doc.sections[0].text
assert section_text is not None
assert "Empty Page" in section_text
assert "<p>" not in section_text
def test_convert_assignment_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=20,
name="Homework 1",
description="<p>Solve these</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at="2025-02-01T23:59:00Z",
)
doc = self.connector._convert_assignment_to_document(assignment)
expected_id = "canvas-assignment-1-20"
expected_due_text = "Due: February 01, 2025 23:59 UTC"
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Homework 1"
assert doc.sections[0].text is not None
assert expected_due_text in doc.sections[0].text
def test_convert_assignment_without_description(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=21,
name="Quiz 1",
description=None,
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at=None,
)
doc = self.connector._convert_assignment_to_document(assignment)
section_text = doc.sections[0].text
assert section_text is not None
assert "Quiz 1" in section_text
assert "Due:" not in section_text
def test_convert_announcement_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=30,
title="Class Cancelled",
message="<p>No class today</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
posted_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
expected_id = "canvas-announcement-1-30"
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Class Cancelled"
assert doc.doc_updated_at == expected_updated_at
def test_convert_announcement_without_posted_at(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=31,
title="TBD Announcement",
message=None,
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
posted_at=None,
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
assert doc.doc_updated_at is None
# ---------------------------------------------------------------------------
# CanvasConnector — validate_connector_settings
# ---------------------------------------------------------------------------
class TestValidateConnectorSettings:
def _assert_validate_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert validate_connector_settings raises expected_error."""
success_resp = _mock_response(json_data=[_mock_course()])
fail_resp = _mock_response(status_code, {})
mock_requests.get.side_effect = [success_resp, fail_resp]
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
with pytest.raises(expected_error):
connector.validate_connector_settings()
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = _build_connector()
connector.validate_connector_settings() # should not raise
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
# ---------------------------------------------------------------------------
# _list_* pagination tests
# ---------------------------------------------------------------------------
class TestListCourses:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
)
connector = _build_connector()
result = connector._list_courses()
assert len(result) == 2
assert result[0].id == 1
assert result[1].id == 2
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_courses()
assert result == []
class TestListPages:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_page(10), _mock_page(11, "Notes")]
)
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert len(result) == 2
assert result[0].page_id == 10
assert result[1].page_id == 11
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert result == []
class TestListAssignments:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
)
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert len(result) == 2
assert result[0].id == 20
assert result[1].id == 21
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert result == []
class TestListAnnouncements:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
)
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert len(result) == 2
assert result[0].id == 30
assert result[1].id == 31
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert result == []

View File

@@ -1,147 +0,0 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
mock = MagicMock(spec=JIRA)
mock._options = {"server": "https://jira.example.com"}
mock._session = MagicMock()
mock._get_url = MagicMock(
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
)
return mock
def test_bulk_fetch_success() -> None:
"""Happy path: all issues fetched in one request."""
client = _mock_jira_client()
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
result = bulk_fetch_issues(client, ["1", "2", "3"])
assert len(result) == 3
assert all(isinstance(r, Issue) for r in result)
client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
client = _mock_jira_client()
call_count = 0
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
nonlocal call_count
call_count += 1
ids = json["issueIdsOrKeys"]
if len(ids) > 2:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 2294125
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
assert len(result) == 4
returned_ids = {r.raw["id"] for r in result}
assert returned_ids == {"1", "2", "3", "4"}
assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
"""A single issue that always fails JSON decode raises after splitting."""
client = _mock_jira_client()
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if "bad" in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 100
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
"""Non-JSONDecodeError exceptions still propagate."""
client = _mock_jira_client()
resp = MagicMock()
resp.json.side_effect = ValueError("something else broke")
client._session.post.return_value = resp
try:
bulk_fetch_issues(client, ["1"])
assert False, "Expected ValueError to propagate"
except ValueError:
pass
def test_bulk_fetch_with_fields() -> None:
"""Fields parameter is forwarded correctly."""
client = _mock_jira_client()
raw = [_make_raw_issue("1")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
bulk_fetch_issues(client, ["1"], fields="summary,description")
call_payload = client._session.post.call_args[1]["json"]
assert call_payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
client = _mock_jira_client()
bad_id = "BAD"
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if bad_id in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"truncated", "doc", 999
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -1,5 +1,3 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -33,7 +31,6 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -68,75 +65,3 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that generic_doc_sync derives indexing_start from cc_pair
and forwards it to retrieve_all_slim_docs_perm_sync."""
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that indexing_start is None when the connector has no indexing_start set."""
mock_jira_cc_pair.connector.indexing_start = None
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] is None

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
mock_db_session.delete.assert_called_once_with(provider)
assert provider.deleted is True
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

Some files were not shown because too many files have changed in this diff Show More