Compare commits

..

8 Commits

Author SHA1 Message Date
Nik
fe942daf30 refactor(be): migrate search_settings & image_generation to OnyxError
- Replace all HTTPException raises with OnyxError in:
  - backend/onyx/server/manage/search_settings.py (6 occurrences)
  - backend/onyx/server/manage/image_generation/api.py (17 occurrences)
- Update integration tests for new response shape (detail → message)
- "already exists" now returns 409 (DUPLICATE_RESOURCE) instead of 400

Error code mappings:
- 501 NOT_IMPLEMENTED → OnyxErrorCode.NOT_IMPLEMENTED
- 400 validation → OnyxErrorCode.VALIDATION_ERROR
- 404 not found → OnyxErrorCode.NOT_FOUND
- 400 duplicate → OnyxErrorCode.DUPLICATE_RESOURCE (409)
- 400/401 bad creds → OnyxErrorCode.CREDENTIAL_INVALID
2026-03-04 14:17:27 -08:00
Nik
edad23a7b7 fix(test): update persona access tests for OnyxError response shape 2026-03-04 12:55:33 -08:00
Nik
94ab65e47d fix(test): update integration tests for OnyxError response shape
- "detail" → "message" in response body assertions
- 400 → 409 for duplicate provider (DUPLICATE_RESOURCE)
2026-03-04 12:03:09 -08:00
Nik
cce2a2c2d4 fix(test): update api_base tests to expect OnyxError instead of HTTPException 2026-03-04 11:37:00 -08:00
Nik
628b88740f fix(test): update LLM provider tests to expect OnyxError
Update 3 test assertions from pytest.raises(HTTPException) to
pytest.raises(OnyxError) to match the migrated error handling.
2026-03-04 11:15:30 -08:00
Nik
06cf7c5bdd fix(be): address Greptile review feedback
- Bedrock credential/config errors → CREDENTIAL_INVALID (not BAD_GATEWAY)
- Move guard check outside try block in delete_llm_provider for clarity
2026-03-04 10:52:31 -08:00
Nik
b9c7c1cd3b fix(be): use semantic error codes instead of blanket VALIDATION_ERROR
- "already exists" → DUPLICATE_RESOURCE (409)
- "does not exist" on update → NOT_FOUND (404)
- External service failures (Bedrock/Ollama/OpenRouter) → BAD_GATEWAY (502)
2026-03-04 10:33:02 -08:00
Nik
4dfc64d6cf refactor(be): migrate LLM & embedding management to OnyxError
Replace all HTTPException raises in manage/llm/api.py (25) and
manage/embedding/api.py (2) with OnyxError using standardized error
codes. Part of the ongoing OnyxError rollout.
2026-03-04 10:17:25 -08:00
651 changed files with 8785 additions and 27432 deletions

View File

@@ -1,186 +0,0 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
```bash
onyx-cli validate-config
```
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
## Commands
### Validate configuration
```bash
onyx-cli validate-config
```
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
Each line is a JSON object with this envelope:
```json
{"type": "<event_type>", "event": { ... }}
```
| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |
`citation_info` event shape:
```json
{
"type": "citation_info",
"event": {
"citation_number": 1,
"document_id": "abc123def456",
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
}
}
```
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## Statelessness
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

View File

@@ -182,52 +182,8 @@ jobs:
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
# Create GitHub release first, before desktop builds start.
# This ensures all desktop matrix jobs upload to the same release instead of
# racing to create duplicate releases.
create-release:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
runs-on: ubuntu-slim
timeout-minutes: 10
permissions:
contents: write
outputs:
release-id: ${{ steps.create-release.outputs.id }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Determine release tag
id: release-tag
env:
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
run: |
if [ "${IS_TEST_RUN}" == "true" ]; then
echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
else
echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
fi
- name: Create GitHub Release
id: create-release
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
with:
tag_name: ${{ steps.release-tag.outputs.tag }}
name: ${{ steps.release-tag.outputs.tag }}
body: "See the assets to download this version and install."
draft: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
build-desktop:
needs:
- determine-builds
- create-release
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
id-token: write
@@ -252,12 +208,12 @@ jobs:
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
with:
# NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]
- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -397,9 +353,11 @@ jobs:
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
# Use the release created by the create-release job to avoid race conditions
# when multiple matrix jobs try to create/update the same release simultaneously
releaseId: ${{ needs.create-release.outputs.release-id }}
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}
@@ -426,7 +384,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -500,7 +458,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -569,7 +527,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -639,7 +597,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -721,7 +679,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -798,7 +756,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -865,7 +823,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -938,7 +896,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1006,7 +964,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1076,7 +1034,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1149,7 +1107,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1218,7 +1176,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1288,7 +1246,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1368,7 +1326,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1442,7 +1400,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1507,7 +1465,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1562,7 +1520,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1622,7 +1580,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -1679,7 +1637,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

View File

@@ -15,8 +15,7 @@ permissions:
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
secrets:
AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
secrets: inherit
permissions:
contents: read
id-token: write

View File

@@ -6,13 +6,11 @@ on:
- main
permissions:
contents: read
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
permissions:
contents: write
pull-requests: write
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}

View File

@@ -57,7 +57,7 @@ jobs:
cache-dependency-path: ./desktop/package-lock.json
- name: Setup Rust
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
with:
toolchain: stable
targets: ${{ matrix.target }}

View File

@@ -1,56 +0,0 @@
name: Golang Tests
concurrency:
group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions: {}
env:
GO_VERSION: "1.26"
jobs:
detect-modules:
runs-on: ubuntu-latest
timeout-minutes: 10
outputs:
modules: ${{ steps.set-modules.outputs.modules }}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
with:
persist-credentials: false
- id: set-modules
run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
golang:
needs: detect-modules
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
matrix:
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
with:
go-version: ${{ env.GO_VERSION }}
cache-dependency-path: "**/go.sum"
- run: go mod tidy
working-directory: ${{ matrix.modules }}
- run: git diff --exit-code go.mod go.sum
working-directory: ${{ matrix.modules }}
- run: go test ./...
working-directory: ${{ matrix.modules }}

View File

@@ -71,7 +71,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
- name: Pre-install cluster status check
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -316,7 +316,6 @@ jobs:
# Base config shared by both editions
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
OPENSEARCH_FOR_ONYX_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -336,6 +335,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
@@ -419,7 +419,6 @@ jobs:
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
@@ -639,7 +638,6 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
DEV_MODE=true \
OPENSEARCH_FOR_ONYX_ENABLED=false \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
index \
@@ -694,7 +692,6 @@ jobs:
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \

View File

@@ -31,7 +31,7 @@ jobs:
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies

View File

@@ -12,9 +12,6 @@ on:
push:
tags:
- "v*.*.*"
# TODO: Remove this if we enable merge-queues for release branches.
branches:
- "release/**"
permissions:
contents: read
@@ -271,11 +268,10 @@ jobs:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -283,7 +279,6 @@ jobs:
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
@@ -464,7 +459,7 @@ jobs:
# --- Visual Regression Diff ---
- name: Configure AWS credentials
if: always()
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -595,108 +590,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-tests-lite:
needs: [build-web-image, build-backend-image]
name: Playwright Tests (lite)
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-lite"
- "extras=ecr-cache"
timeout-minutes: 30
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-npm-
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
MOCK_LLM_RESPONSE=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers (lite)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Run Playwright tests (lite)
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
path: ./web/output/playwright/
retention-days: 30
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
@@ -793,7 +686,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests, playwright-tests-lite]
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -38,9 +38,9 @@ jobs:
- name: Install node dependencies
working-directory: ./web
run: npm ci
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
with:
prek-version: '0.3.4'
prek-version: '0.2.21'
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
- name: Check Actions
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1

View File

@@ -1,214 +0,0 @@
name: Release CLI
on:
push:
tags:
- "cli/v*.*.*"
jobs:
pypi:
runs-on: ubuntu-latest
environment:
name: release-cli
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- run: |
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
working-directory: cli
- run: uv publish
working-directory: cli
docker-amd64:
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-cli-amd64
- extras=ecr-cache
environment: deploy
permissions:
id-token: write
timeout-minutes: 30
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
with:
context: ./cli
file: ./cli/Dockerfile
platforms: linux/amd64
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
docker-arm64:
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-cli-arm64
- extras=ecr-cache
environment: deploy
permissions:
id-token: write
timeout-minutes: 30
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
with:
context: ./cli
file: ./cli/Dockerfile
platforms: linux/arm64
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
merge-docker:
needs:
- docker-amd64
- docker-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-cli-merge
environment: deploy
permissions:
id-token: write
timeout-minutes: 10
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Create and push manifest
env:
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
TAG: ${{ github.ref_name }}
run: |
SANITIZED_TAG="${TAG#cli/}"
IMAGES=(
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
)
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
docker buildx imagetools create \
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
-t "${REGISTRY_IMAGE}:latest" \
"${IMAGES[@]}"
else
docker buildx imagetools create \
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
"${IMAGES[@]}"
fi

View File

@@ -22,10 +22,12 @@ jobs:
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
- { goos: "", goarch: "" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false

View File

@@ -48,10 +48,6 @@ on:
required: false
default: true
type: boolean
secrets:
AWS_OIDC_ROLE_ARN:
description: "AWS role ARN for OIDC auth"
required: true
permissions:
contents: read
@@ -77,7 +73,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -120,7 +116,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -162,7 +158,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -268,7 +264,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

View File

@@ -110,7 +110,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -180,7 +180,7 @@ jobs:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
@@ -244,7 +244,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

View File

@@ -119,11 +119,10 @@ repos:
]
- repo: https://github.com/golangci/golangci-lint
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
language_version: "1.26.0"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.

58
.vscode/launch.json vendored
View File

@@ -40,7 +40,19 @@
}
},
{
"name": "Celery",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
@@ -241,6 +253,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -485,6 +526,21 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// Generic debugger type, required arg but has no bearing on bash.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
@@ -104,10 +135,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Always use `@shared_task` rather than `@celery_app`
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
- Never enqueue a task without an expiration. Always supply `expires=` when
sending tasks, either from the beat schedule or directly from another task. It
should never be acceptable to submit code which enqueues tasks without an
expiration, as doing so can lead to unbounded task queue growth.
**Defining APIs**:
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
@@ -544,8 +571,6 @@ To run them:
npx playwright test <TEST_NAME>
```
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
## Logs
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
@@ -598,7 +623,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
handling consistent across the entire backend.
```python

View File

@@ -46,9 +46,7 @@ RUN apt-get update && \
pkg-config \
gcc \
nano \
vim \
libjemalloc2 \
&& \
vim && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -166,13 +164,6 @@ ENV PYTHONPATH=/app
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
# in long-running Python processes (API server, Celery workers).
# The soname is architecture-independent; the dynamic linker resolves
# the correct path from standard library directories.
# Placed after all RUN steps so build-time processes are unaffected.
ENV LD_PRELOAD=libjemalloc.so.2
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]

View File

@@ -0,0 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -472,9 +471,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name,
time_last_modified_by_user=func.now(),
is_up_to_date=DISABLE_VECTOR_DB,
name=user_group.name, time_last_modified_by_user=func.now()
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -777,7 +774,8 @@ def update_user_group(
cc_pair_ids=user_group_update.cc_pair_ids,
)
if cc_pairs_updated and not DISABLE_VECTOR_DB:
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(

View File

@@ -68,7 +68,6 @@ def get_external_access_for_raw_gdrive_file(
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
fallback_user_email: str,
add_prefix: bool = False,
) -> ExternalAccess:
"""
@@ -80,11 +79,6 @@ def get_external_access_for_raw_gdrive_file(
set add_prefix to True so group IDs are prefixed with the source type.
When invoked from doc_sync (permission sync), use the default (False)
since upsert_document_external_perms handles prefixing.
fallback_user_email: When we cannot retrieve any permission info for a file
(e.g. externally-owned files where the API returns no permissions
and permissions.list returns 403), fall back to granting access
to this user. This is typically the impersonated org user whose
drive contained the file.
"""
doc_id = file.get("id")
if not doc_id:
@@ -123,26 +117,6 @@ def get_external_access_for_raw_gdrive_file(
[permissions_list, backup_permissions_list]
)
# For externally-owned files, the Drive API may return no permissions
# and permissions.list may return 403. In this case, fall back to
# granting access to the user who found the file in their drive.
# Note, even if other users also have access to this file,
# they will not be granted access in Onyx.
# We check permissions_list (the final result after all fetch attempts)
# rather than the raw fields, because permission_ids may be present
# but the actual fetch can still return empty due to a 403.
if not permissions_list:
logger.info(
f"No permission info available for file {doc_id} "
f"(likely owned by a user outside of your organization). "
f"Falling back to granting access to retriever user: {fallback_user_email}"
)
return ExternalAccess(
external_user_emails={fallback_user_email},
external_user_group_ids=set(),
is_public=False,
)
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()
group_emails: set[str] = set()

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any
from jira import JIRA
from jira.exceptions import JIRAError
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.jira.utils import build_jira_client
@@ -11,102 +9,107 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
_GROUP_MEMBER_PAGE_SIZE = 50
# The GET /group/member endpoint was introduced in Jira 6.0.
# Jira versions older than 6.0 do not have group management REST APIs at all.
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
def _fetch_group_member_page(
def _get_jira_group_members_email(
jira_client: JIRA,
group_name: str,
start_at: int,
) -> dict[str, Any]:
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
) -> list[str]:
"""Get all member emails for a Jira group.
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
directly via the library's internal _get_json helper, following the same pattern
as enhanced_search_ids / bulk_fetch_issues in connector.py.
There is an open PR to the library to switch to this endpoint since last year:
https://github.com/pycontribs/jira/pull/2356
so once it is merged and released, we can switch to using the library function.
Filters out app accounts (bots, integrations) and only returns real user emails.
"""
emails: list[str] = []
try:
return jira_client._get_json(
"group/member",
params={
"groupname": group_name,
"includeInactiveUsers": "false",
"startAt": start_at,
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
},
)
except JIRAError as e:
if e.status_code == 404:
raise RuntimeError(
f"GET /group/member returned 404 for group '{group_name}'. "
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
f"If you are running a self-hosted Jira instance, please upgrade "
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
) from e
raise
# group_members returns an OrderedDict of account_id -> member_info
members = jira_client.group_members(group=group_name)
if not members:
logger.warning(f"No members found for group {group_name}")
return emails
def _get_group_member_emails(
jira_client: JIRA,
group_name: str,
) -> set[str]:
"""Get all member emails for a single Jira group.
for account_id, member_info in members.items():
# member_info is a dict with keys like 'fullname', 'email', 'active'
email = member_info.get("email")
Uses the non-deprecated GET /group/member endpoint which returns full user
objects including accountType, so we can filter out app/customer accounts
without making separate user() calls.
"""
emails: set[str] = set()
start_at = 0
while True:
try:
page = _fetch_group_member_page(jira_client, group_name, start_at)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
raise
members: list[dict[str, Any]] = page.get("values", [])
for member in members:
account_type = member.get("accountType")
# On Jira DC < 9.0, accountType is absent; include those users.
# On Cloud / DC 9.0+, filter to real user accounts only.
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
continue
email = member.get("emailAddress")
if email:
emails.add(email)
# Skip "hidden" emails - these are typically app accounts
if email and email != "hidden":
emails.append(email)
else:
logger.warning(
f"Atlassian user {member.get('accountId', 'unknown')} "
f"in group {group_name} has no visible email address"
)
# For cloud, we might need to fetch user details separately
try:
user = jira_client.user(id=account_id)
if page.get("isLast", True) or not members:
break
start_at += len(members)
# Skip app accounts (bots, integrations, etc.)
if hasattr(user, "accountType") and user.accountType == "app":
logger.info(
f"Skipping app account {account_id} for group {group_name}"
)
continue
if hasattr(user, "emailAddress") and user.emailAddress:
emails.append(user.emailAddress)
else:
logger.warning(f"User {account_id} has no email address")
except Exception as e:
logger.warning(
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
return emails
def _build_group_member_email_map(
jira_client: JIRA,
) -> dict[str, set[str]]:
"""Build a map of group names to member emails."""
group_member_emails: dict[str, set[str]] = {}
try:
# Get all groups from Jira - returns a list of group name strings
group_names = jira_client.groups()
if not group_names:
logger.warning("No groups found in Jira")
return group_member_emails
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
continue
member_emails = _get_jira_group_members_email(
jira_client=jira_client,
group_name=group_name,
)
if member_emails:
group_member_emails[group_name] = set(member_emails)
logger.debug(
f"Found {len(member_emails)} members for group {group_name}"
)
else:
logger.debug(f"No members found for group {group_name}")
except Exception as e:
logger.error(f"Error building group member email map: {e}")
return group_member_emails
def jira_group_sync(
tenant_id: str, # noqa: ARG001
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""Sync Jira groups and their members, yielding one group at a time.
"""
Sync Jira groups and their members.
Streams group-by-group rather than accumulating all groups in memory.
This function fetches all groups from Jira and yields ExternalUserGroup
objects containing the group ID and member emails.
"""
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
scoped_token = cc_pair.connector.connector_specific_config.get(
@@ -127,26 +130,12 @@ def jira_group_sync(
scoped_token=scoped_token,
)
group_names = jira_client.groups()
if not group_names:
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
if not group_member_email_map:
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
continue
member_emails = _get_group_member_emails(
jira_client=jira_client,
group_name=group_name,
)
if not member_emails:
logger.debug(f"No members found for group {group_name}")
continue
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
for group_id, group_member_emails in group_member_email_map.items():
yield ExternalUserGroup(
id=group_name,
user_emails=list(member_emails),
id=group_id,
user_emails=list(group_member_emails),
)

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -246,11 +246,7 @@ async def get_billing_information(
)
except OnyxError as e:
# Open circuit breaker on connection failures (self-hosted only)
if e.status_code in (
OnyxErrorCode.BAD_GATEWAY.status_code,
OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
):
if e.status_code in (502, 503, 504):
_open_billing_circuit()
raise

View File

@@ -223,15 +223,6 @@ def get_active_scim_token(
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
# Derive the IdP domain from the first synced user as a heuristic.
idp_domain: str | None = None
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
if mappings:
user = dal.get_user(mappings[0].user_id)
if user and "@" in user.email:
idp_domain = user.email.rsplit("@", 1)[1]
return ScimTokenResponse(
id=token.id,
name=token.name,
@@ -239,7 +230,6 @@ def get_active_scim_token(
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
idp_domain=idp_domain,
)

View File

@@ -365,7 +365,6 @@ class ScimTokenResponse(BaseModel):
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
idp_domain: str | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):

View File

@@ -26,7 +26,6 @@ from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
from onyx.utils.logger import setup_logger
@@ -126,16 +125,10 @@ def _seed_llms(
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers: list[LLMProviderView] = []
for llm_upsert_request in llm_upsert_requests:
try:
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
except ValueError as e:
logger.warning(
"Failed to upsert LLM provider '%s' during seeding: %s",
llm_upsert_request.name,
e,
)
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None

View File

@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -156,8 +153,3 @@ def delete_user_group(
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if DISABLE_VECTOR_DB:
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)

View File

@@ -0,0 +1,142 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received for consolidated background worker.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
# Initialize Vespa httpx pool (needed for light worker tasks)
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector.
"""Result of extracting document IDs and hierarchy nodes from a connector."""
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
"""
raw_id_to_parent: dict[str, str | None]
doc_ids: set[str]
hierarchy_nodes: list[HierarchyNode]
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
return None
class BatchResult(BaseModel):
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> BatchResult:
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID dict so that failed-to-retrieve documents are not accidentally pruned.
ID set so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: dict[str, str | None] = {}
ids: set[str] = set()
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
ids.add(item.raw_node_id)
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids[failed_id] = None
ids.add(failed_id)
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
ids.add(item.id)
return ids, hierarchy_nodes
def extract_ids_from_runnable_connector(
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_raw_id_to_parent: dict[str, str | None] = {}
all_connector_doc_ids: set[str] = set()
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
raw_id_to_parent=all_raw_id_to_parent,
doc_ids=all_connector_doc_ids,
hierarchy_nodes=all_hierarchy_nodes,
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
super().progress(tag, amount)
def _resolve_and_update_document_parents(
db_session: Session,
redis_client: Redis,
source: DocumentSource,
raw_id_to_parent: dict[str, str | None],
) -> None:
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
each document and bulk-update the DB. Mirrors the resolution logic in
run_docfetching.py."""
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent_id in raw_id_to_parent.items():
if raw_parent_id is None:
continue
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
resolved[doc_id] = node_id if found else source_node_id
if not resolved:
return
update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
task_logger.info(
f"Pruning: resolved and updated parent hierarchy for "
f"{len(resolved)} documents (source={source.value})"
)
"""Jobs / utils for kicking off pruning tasks."""
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
all_connector_doc_ids = extraction_result.doc_ids
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
ensure_source_node_exists(redis_client, db_session, source)
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=source,
source=cc_pair.connector.source,
commit=True,
is_connector_public=is_connector_public,
)
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=source,
source=cc_pair.connector.source,
entries=cache_entries,
)
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
db_session=db_session,
redis_client=redis_client,
source=source,
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
"Pruning set collected: "

View File

@@ -0,0 +1,10 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)

View File

@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -69,8 +71,6 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR

View File

@@ -36,6 +36,7 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -50,7 +51,6 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import CustomToolCallSummary
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import PythonToolRichResponse
from onyx.tools.models import ToolCallInfo
@@ -84,6 +84,28 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
)
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
if model_provider not in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}:
return False
return any(
(
msg.message_type == MessageType.ASSISTANT
and msg.tool_calls
and len(msg.tool_calls) > 0
)
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
for msg in simple_chat_history
)
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
@@ -664,7 +686,12 @@ def run_llm_loop(
elif out_of_cycles or ran_image_gen:
# Last cycle, no tools allowed, just answer!
tool_choice = ToolChoiceOptions.NONE
final_tools = []
# Bedrock requires tool config in requests that include toolUse/toolResult history.
final_tools = (
tools
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
else []
)
else:
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
@@ -981,10 +1008,6 @@ def run_llm_loop(
if memory_snapshot:
saved_response = json.dumps(memory_snapshot.model_dump())
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
saved_response = json.dumps(
tool_response.rich_response.model_dump()
)
elif isinstance(tool_response.rich_response, str):
saved_response = tool_response.rich_response
else:

View File

@@ -55,7 +55,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
@@ -167,6 +166,15 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
search_from = idx + 1
def _sanitize_llm_output(value: str) -> str:
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
- NULL bytes (\x00): Not allowed in PostgreSQL text types
- UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
"""
return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))
def _try_parse_json_string(value: Any) -> Any:
"""Attempt to parse a JSON string value into its Python equivalent.
@@ -214,7 +222,9 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
if isinstance(raw_args, dict):
# Parse any string values that look like JSON arrays/objects
return {
k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
k: _try_parse_json_string(
_sanitize_llm_output(v) if isinstance(v, str) else v
)
for k, v in raw_args.items()
}
@@ -222,7 +232,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
return {}
# Sanitize before parsing to remove NULL bytes and surrogates
raw_args = sanitize_string(raw_args)
raw_args = _sanitize_llm_output(raw_args)
try:
parsed1: Any = json.loads(raw_args)
@@ -535,12 +545,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
)
if not attr_match:
return None
return sanitize_string(unescape(attr_match.group(2).strip()))
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
"""Parse a parameter value from XML-style tool call payloads."""
value = sanitize_string(unescape(raw_value).strip())
value = _sanitize_llm_output(unescape(raw_value).strip())
if string_attr and string_attr.lower() == "true":
return value
@@ -559,7 +569,6 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
"""
arguments = obj.get("arguments", obj.get("parameters", {}))
if isinstance(arguments, str):
arguments = sanitize_string(arguments)
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:

View File

@@ -19,7 +19,6 @@ from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.tools.models import ToolCallInfo
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
logger = setup_logger()
@@ -202,13 +201,8 @@ def save_chat_turn(
pre_answer_processing_time: Duration of processing before answer starts (in seconds)
"""
# 1. Update ChatMessage with message content, reasoning tokens, and token count
sanitized_message_text = (
sanitize_string(message_text) if message_text else message_text
)
assistant_message.message = sanitized_message_text
assistant_message.reasoning_tokens = (
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
)
assistant_message.message = message_text
assistant_message.reasoning_tokens = reasoning_tokens
assistant_message.is_clarification = is_clarification
# Use pre-answer processing time (captured when MESSAGE_START was emitted)
@@ -218,10 +212,8 @@ def save_chat_turn(
# Calculate token count using default tokenizer, when storing, this should not use the LLM
# specific one so we use a system default tokenizer here.
default_tokenizer = get_tokenizer(None, None)
if sanitized_message_text:
assistant_message.token_count = len(
default_tokenizer.encode(sanitized_message_text)
)
if message_text:
assistant_message.token_count = len(default_tokenizer.encode(message_text))
else:
assistant_message.token_count = 0
@@ -336,10 +328,8 @@ def save_chat_turn(
# 8. Attach code interpreter generated files that the assistant actually
# referenced in its response, so they are available via load_all_chat_files
# on subsequent turns. Files not mentioned are intermediate artifacts.
if sanitized_message_text:
referenced = _extract_referenced_file_descriptors(
tool_calls, sanitized_message_text
)
if message_text:
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
if referenced:
existing_files = assistant_message.files or []
assistant_message.files = existing_files + referenced

View File

@@ -288,9 +288,8 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
# of OpenSearch running for the relevant Onyx instance.
# NOTE: Now enabled on by default, unless the env indicates otherwise.
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
@@ -496,7 +495,14 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
# Individual worker concurrency settings
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)

View File

@@ -84,6 +84,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (

View File

@@ -943,9 +943,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
page
),
)
)
@@ -995,7 +992,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=page_id,
)
)

View File

@@ -1,5 +1,4 @@
import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
@@ -205,7 +204,7 @@ def _manage_async_retrieval(
end_time: datetime | None = end
async def _async_fetch() -> AsyncGenerator[Document, None]:
async def _async_fetch() -> AsyncIterable[Document]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents) as discord_client:
@@ -228,23 +227,22 @@ def _manage_async_retrieval(
def run_and_yield() -> Iterable[Document]:
loop = asyncio.new_event_loop()
async_gen = _async_fetch()
try:
# Get the async generator
async_gen = _async_fetch()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
doc = loop.run_until_complete(anext(async_gen))
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next document
doc = loop.run_until_complete(next_coro)
yield doc
except StopAsyncIteration:
break
finally:
# Must close the async generator before the loop so the Discord
# client's `async with` block can await its shutdown coroutine.
# The nested try/finally ensures the loop always closes even if
# aclose() raises (same pattern as cursor.close() before conn.close()).
try:
loop.run_until_complete(async_gen.aclose())
finally:
loop.close()
loop.close()
return run_and_yield()

View File

@@ -1722,7 +1722,6 @@ class GoogleDriveConnector(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
),
retriever_email=file.user_email,
):
slim_batch.append(doc)

View File

@@ -476,7 +476,6 @@ def _get_external_access_for_raw_gdrive_file(
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
fallback_user_email: str,
add_prefix: bool = False,
) -> ExternalAccess:
"""
@@ -485,8 +484,6 @@ def _get_external_access_for_raw_gdrive_file(
add_prefix: When True, prefix group IDs with source type (for indexing path).
When False (default), leave unprefixed (for permission sync path
where upsert_document_external_perms handles prefixing).
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
files), fall back to granting access to this user.
"""
external_access_fn = cast(
Callable[
@@ -495,7 +492,6 @@ def _get_external_access_for_raw_gdrive_file(
str,
GoogleDriveService | None,
GoogleDriveService,
str,
bool,
],
ExternalAccess,
@@ -511,7 +507,6 @@ def _get_external_access_for_raw_gdrive_file(
company_domain,
retriever_drive_service,
admin_drive_service,
fallback_user_email,
add_prefix,
)
@@ -677,7 +672,6 @@ def _convert_drive_item_to_document(
creds, user_email=permission_sync_context.primary_admin_email
),
add_prefix=True, # Indexing path - prefix here
fallback_user_email=retriever_email,
)
if permission_sync_context
else None
@@ -759,7 +753,6 @@ def build_slim_document(
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
retriever_email: str,
) -> SlimDocument | None:
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
@@ -781,7 +774,6 @@ def build_slim_document(
creds,
user_email=permission_sync_context.primary_admin_email,
),
fallback_user_email=retriever_email,
)
if permission_sync_context
else None
@@ -789,5 +781,4 @@ def build_slim_document(
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
)

View File

@@ -902,11 +902,6 @@ class JiraConnector(
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
current_offset += 1

View File

@@ -385,7 +385,6 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
id: str
external_access: ExternalAccess | None = None
parent_hierarchy_raw_node_id: str | None = None
class HierarchyNode(BaseModel):

View File

@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
drive_name: str,
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
return SlimDocument(
id=driveitem.id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
def _convert_sitepage_to_slim_document(
site_page: dict[str, Any],
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
return SlimDocument(
id=id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
@@ -1600,22 +1594,12 @@ class SharepointConnector(
)
)
parent_hierarchy_url: str | None = None
if drive_web_url:
parent_hierarchy_url = self._get_parent_hierarchy_url(
site_url, drive_web_url, drive_name, driveitem
)
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem,
drive_name,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
driveitem, drive_name, ctx, self.graph_client
)
)
except Exception as e:
@@ -1635,10 +1619,7 @@ class SharepointConnector(
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
site_page, ctx, self.graph_client
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:

View File

@@ -565,7 +565,6 @@ def _get_all_doc_ids(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
parent_hierarchy_raw_node_id=channel_id,
)
)

View File

@@ -38,7 +38,6 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
logger = setup_logger()
@@ -676,43 +675,58 @@ def set_as_latest_chat_message(
db_session.commit()
def _sanitize_for_postgres(value: str) -> str:
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
sanitized = value.replace("\x00", "")
if value and not sanitized:
logger.warning("Sanitization removed all characters from string")
return sanitized
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
"""Remove NUL (0x00) characters from all strings in a list."""
return [_sanitize_for_postgres(v) for v in values]
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
commit: bool = True,
) -> DBSearchDoc:
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
db_search_doc = DBSearchDoc(
document_id=sanitize_string(server_search_doc.document_id),
document_id=_sanitize_for_postgres(server_search_doc.document_id),
chunk_ind=server_search_doc.chunk_ind,
semantic_id=sanitize_string(server_search_doc.semantic_identifier),
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
link=(
sanitize_string(server_search_doc.link)
_sanitize_for_postgres(server_search_doc.link)
if server_search_doc.link is not None
else None
),
blurb=sanitize_string(server_search_doc.blurb),
blurb=_sanitize_for_postgres(server_search_doc.blurb),
source_type=server_search_doc.source_type,
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=(
sanitize_string(server_search_doc.relevance_explanation)
_sanitize_for_postgres(server_search_doc.relevance_explanation)
if server_search_doc.relevance_explanation is not None
else None
),
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=[
sanitize_string(h) for h in server_search_doc.match_highlights
],
match_highlights=_sanitize_list_for_postgres(
server_search_doc.match_highlights
),
updated_at=server_search_doc.updated_at,
primary_owners=(
[sanitize_string(o) for o in server_search_doc.primary_owners]
_sanitize_list_for_postgres(server_search_doc.primary_owners)
if server_search_doc.primary_owners is not None
else None
),
secondary_owners=(
[sanitize_string(o) for o in server_search_doc.secondary_owners]
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
if server_search_doc.secondary_owners is not None
else None
),

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -247,7 +246,6 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
is_up_to_date=DISABLE_VECTOR_DB,
time_last_modified_by_user=func.now(),
)
db_session.add(new_document_set_row)
@@ -338,8 +336,7 @@ def update_document_set(
)
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(

View File

@@ -1,7 +1,5 @@
"""CRUD operations for HierarchyNode."""
from collections import defaultdict
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
return {doc_id: parent_id for doc_id, parent_id in results}
def update_document_parent_hierarchy_nodes(
db_session: Session,
doc_parent_map: dict[str, int | None],
commit: bool = True,
) -> int:
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
Only updates rows whose current value differs from the desired value to
avoid unnecessary writes.
Args:
db_session: SQLAlchemy session
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
commit: Whether to commit the transaction
Returns:
Number of documents actually updated
"""
if not doc_parent_map:
return 0
doc_ids = list(doc_parent_map.keys())
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
by_parent: dict[int | None, list[str]] = defaultdict(list)
for doc_id, desired_parent_id in doc_parent_map.items():
current = existing.get(doc_id)
if current == desired_parent_id or doc_id not in existing:
continue
by_parent[desired_parent_id].append(doc_id)
updated = 0
for desired_parent_id, ids in by_parent.items():
db_session.query(Document).filter(Document.id.in_(ids)).update(
{Document.parent_hierarchy_node_id: desired_parent_id},
synchronize_session=False,
)
updated += len(ids)
if commit:
db_session.commit()
elif updated:
db_session.flush()
return updated
def update_hierarchy_node_permissions(
db_session: Session,
raw_node_id: str,

View File

@@ -25,12 +25,8 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
@@ -271,35 +267,10 @@ def upsert_llm_provider(
mc.name for mc in llm_provider_upsert_request.model_configurations
}
# Build a lookup of requested visibility by model name
requested_visibility = {
mc.name: mc.is_visible
for mc in llm_provider_upsert_request.model_configurations
}
# Delete removed models
removed_ids = [
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
]
default_model = fetch_default_llm_model(db_session)
# Prevent removing and hiding the default model
if default_model:
for name, mc in existing_by_name.items():
if mc.id == default_model.id:
if default_model.id in removed_ids:
raise ValueError(
f"Cannot remove the default model '{name}'. "
"Please change the default model before removing."
)
if not requested_visibility.get(name, True):
raise ValueError(
f"Cannot hide the default model '{name}'. "
"Please change the default model before hiding."
)
break
if removed_ids:
db_session.query(ModelConfiguration).filter(
ModelConfiguration.id.in_(removed_ids)
@@ -370,9 +341,9 @@ def upsert_llm_provider(
def sync_model_configurations(
db_session: Session,
provider_name: str,
models: list[SyncModelEntry],
models: list[dict],
) -> int:
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
This inserts NEW models from the source API without overwriting existing ones.
User preferences (is_visible, max_input_tokens) are preserved for existing models.
@@ -380,7 +351,7 @@ def sync_model_configurations(
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of SyncModelEntry objects describing the fetched models
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
Returns:
Number of new models added
@@ -394,20 +365,21 @@ def sync_model_configurations(
new_count = 0
for model in models:
if model.name not in existing_names:
model_name = model["name"]
if model_name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
supported_flows = [LLMModelFlowType.CHAT]
if model.supports_image_input:
if model.get("supports_image_input", False):
supported_flows.append(LLMModelFlowType.VISION)
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model.name,
model_name=model_name,
supported_flows=supported_flows,
is_visible=False,
max_input_tokens=model.max_input_tokens,
display_name=model.display_name,
max_input_tokens=model.get("max_input_tokens"),
display_name=model.get("display_name"),
)
new_count += 1
@@ -563,6 +535,7 @@ def fetch_default_model(
.options(selectinload(ModelConfiguration.llm_provider))
.join(LLMModelFlow)
.where(
ModelConfiguration.is_visible == True, # noqa: E712
LLMModelFlow.llm_model_flow_type == flow_type,
LLMModelFlow.is_default == True, # noqa: E712
)
@@ -838,29 +811,6 @@ def sync_auto_mode_models(
)
changes += 1
# Update the default if this provider currently holds the global CHAT default.
# We flush (but don't commit) so that _update_default_model can see the new
# model rows, then commit everything atomically to avoid a window where the
# old default is invisible but still pointed-to.
db_session.flush()
recommended_default = llm_recommendations.get_default_model(provider.provider)
if recommended_default:
current_default = fetch_default_llm_model(db_session)
if (
current_default
and current_default.llm_provider_id == provider.id
and current_default.name != recommended_default.name
):
_update_default_model__no_commit(
db_session=db_session,
provider_id=provider.id,
model=recommended_default.name,
flow_type=LLMModelFlowType.CHAT,
)
changes += 1
db_session.commit()
return changes
@@ -992,7 +942,7 @@ def update_model_configuration__no_commit(
db_session.flush()
def _update_default_model__no_commit(
def _update_default_model(
db_session: Session,
provider_id: int,
model: str,
@@ -1030,14 +980,6 @@ def _update_default_model__no_commit(
new_default.is_default = True
model_config.is_visible = True
def _update_default_model(
db_session: Session,
provider_id: int,
model: str,
flow_type: LLMModelFlowType,
) -> None:
_update_default_model__no_commit(db_session, provider_id, model, flow_type)
db_session.commit()

View File

@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified; DB is not in a valid state.")
raise RuntimeError("No search settings specified, DB is not in a valid state")
return latest_settings

View File

@@ -13,15 +13,12 @@ from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import MCPServerStatus
from onyx.db.models import MCPServer
from onyx.db.models import OAuthConfig
from onyx.db.models import Tool
from onyx.db.models import ToolCall
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_json_like
from onyx.utils.postgres_sanitization import sanitize_string
if TYPE_CHECKING:
pass
@@ -162,26 +159,10 @@ def update_tool(
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
old_oauth_config_id = tool.oauth_config_id
if not isinstance(oauth_config_id, UnsetType):
tool.oauth_config_id = oauth_config_id
db_session.flush()
# Clean up orphaned OAuthConfig if the oauth_config_id was changed
if (
old_oauth_config_id is not None
and not isinstance(oauth_config_id, UnsetType)
and old_oauth_config_id != oauth_config_id
):
other_tools = db_session.scalars(
select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
).all()
if not other_tools:
oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
if oauth_config:
db_session.delete(oauth_config)
db_session.commit()
return tool
@@ -190,21 +171,8 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
if tool is None:
raise ValueError(f"Tool with ID {tool_id} does not exist")
oauth_config_id = tool.oauth_config_id
db_session.delete(tool)
db_session.flush()
# Clean up orphaned OAuthConfig if no other tools reference it
if oauth_config_id is not None:
other_tools = db_session.scalars(
select(Tool).where(Tool.oauth_config_id == oauth_config_id)
).all()
if not other_tools:
oauth_config = db_session.get(OAuthConfig, oauth_config_id)
if oauth_config:
db_session.delete(oauth_config)
db_session.flush()
db_session.flush() # Don't commit yet, let caller decide when to commit
def get_builtin_tool(
@@ -288,13 +256,11 @@ def create_tool_call_no_commit(
tab_index=tab_index,
tool_id=tool_id,
tool_call_id=tool_call_id,
reasoning_tokens=(
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
),
tool_call_arguments=sanitize_json_like(tool_call_arguments),
tool_call_response=sanitize_json_like(tool_call_response),
reasoning_tokens=reasoning_tokens,
tool_call_arguments=tool_call_arguments,
tool_call_response=tool_call_response,
tool_call_tokens=tool_call_tokens,
generated_images=sanitize_json_like(generated_images),
generated_images=generated_images,
)
db_session.add(tool_call)

View File

@@ -1,103 +0,0 @@
# Vector DB Filter Semantics
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
## Filter categories
| Category | Fields | Join logic |
|---|---|---|
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
## How filters combine
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
```
NOT hidden
AND tenant = T -- if multi-tenant
AND (acl contains A1 OR acl contains A2)
AND (source_type = S1 OR ...) -- if set
AND (tag = T1 OR ...) -- if set
AND <knowledge scope> -- see below
AND time >= cutoff -- if set
```
## Knowledge scope rules
The knowledge scope filter controls **what knowledge an assistant can access**.
### No explicit knowledge attached
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id` and `persona_id` are ignored — they never restrict on their own.
### One explicit knowledge type
```
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
```
### Multiple explicit knowledge types (OR'd)
```
-- Document sets + user files
AND (
document_sets contains "Engineering"
OR document_id = "uuid-1"
)
```
### Explicit knowledge + overflowing user files
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
```
-- Document sets + persona user files overflowed
AND (
document_sets contains "Engineering"
OR personas contains 42
)
-- User files + project files overflowed
AND (
document_id = "uuid-1"
OR user_project contains 7
)
```
### Only project_id or persona_id (no explicit knowledge)
No knowledge scope filter. The assistant searches everything.
```
-- Just ACL, no restriction
NOT hidden
AND (acl contains ...)
```
## Field reference
| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |

View File

@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type

View File

@@ -26,10 +26,11 @@ def get_default_document_index(
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated. Updates are applied to both indices.
WARNING: In that case, get_all_document_indices should be used.
need to be updated, updates are applied to both indices.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
@@ -50,26 +51,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
@@ -100,7 +86,8 @@ def get_all_document_indices(
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks are not currently supported so we hardcode appropriate values.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
@@ -136,36 +123,13 @@ def get_all_document_indices(
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)

View File

@@ -61,25 +61,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
explanation: dict[str, Any] | None = None
class IndexInfo(BaseModel):
"""
Represents information about an OpenSearch index.
"""
model_config = {"frozen": True}
name: str
health: str
status: str
num_primary_shards: str
num_replica_shards: str
docs_count: str
docs_deleted: str
created_at: str
total_size: str
primary_shards_size: str
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
"""Recursively replaces vectors in the body with their length.
@@ -178,8 +159,8 @@ class OpenSearchClient(AbstractContextManager):
Raises:
Exception: There was an error creating the search pipeline.
"""
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not response.get("acknowledged", False):
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
@@ -192,8 +173,8 @@ class OpenSearchClient(AbstractContextManager):
Raises:
Exception: There was an error deleting the search pipeline.
"""
response = self._client.search_pipeline.delete(id=pipeline_id)
if not response.get("acknowledged", False):
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
@@ -217,34 +198,6 @@ class OpenSearchClient(AbstractContextManager):
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def list_indices_with_info(self) -> list[IndexInfo]:
"""
Lists the indices in the OpenSearch cluster with information about each
index.
Returns:
A list of IndexInfo objects for each index.
"""
response = self._client.cat.indices(format="json")
indices: list[IndexInfo] = []
for raw_index_info in response:
indices.append(
IndexInfo(
name=raw_index_info.get("index", ""),
health=raw_index_info.get("health", ""),
status=raw_index_info.get("status", ""),
num_primary_shards=raw_index_info.get("pri", ""),
num_replica_shards=raw_index_info.get("rep", ""),
docs_count=raw_index_info.get("docs.count", ""),
docs_deleted=raw_index_info.get("docs.deleted", ""),
created_at=raw_index_info.get("creation.date.string", ""),
total_size=raw_index_info.get("store.size", ""),
primary_shards_size=raw_index_info.get("pri.store.size", ""),
)
)
return indices
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.

View File

@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
secondary_embedding_dim: int | None,
secondary_embedding_precision: EmbeddingPrecision | None,
# NOTE: We do not support large chunks right now.
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
multitenant: bool = False,
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
f"Expected {MULTI_TENANT}, got {multitenant}."
)
tenant_id = get_current_tenant_id()
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
self._real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
self._secondary_real_index: OpenSearchDocumentIndex | None = None
if self.secondary_index_name:
if secondary_embedding_dim is None or secondary_embedding_precision is None:
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
self._secondary_real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=self.secondary_index_name,
embedding_dim=secondary_embedding_dim,
embedding_precision=secondary_embedding_precision,
)
@staticmethod
def register_multitenant_indices(
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
) -> None:
self._real_index.verify_and_create_index_if_necessary(
# Only handle primary index for now, ignore secondary.
return self._real_index.verify_and_create_index_if_necessary(
primary_embedding_dim, primary_embedding_precision
)
if self.secondary_index_name:
if (
secondary_index_embedding_dim is None
or secondary_index_embedding_precision is None
):
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.verify_and_create_index_if_necessary(
secondary_index_embedding_dim, secondary_index_embedding_precision
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
# Convert IndexBatchParams to IndexingMetadata.
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
tenant_id: str, # noqa: ARG002
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
total_chunks_deleted += self._secondary_real_index.delete(
doc_id, chunk_count
)
return total_chunks_deleted
return self._real_index.delete(doc_id, chunk_count)
def update_single(
self,
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> None:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
try:
self._real_index.update([update_request])
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.update([update_request])
except NotFoundError:
logger.exception(
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
@@ -739,8 +681,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
The number of chunks successfully deleted.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
)
query_body = DocumentQuery.delete_from_document_id_query(
document_id=document_id,
@@ -776,8 +717,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
specified documents.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
)
for update_request in update_requests:
properties_to_update: dict[str, Any] = dict()
@@ -833,11 +773,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
# here.
# TODO(andrei): Fix the aforementioned race condition.
raise ChunkCountNotFoundError(
f"Tried to update document {doc_id} but its chunk count is not known. "
"Older versions of the application used to permit this but is not a "
"supported state for a document when using OpenSearch. The document was "
"likely just added to the indexing pipeline and the chunk count will be "
"updated shortly."
f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
"application used to permit this but is not a supported state for a document when using OpenSearch. "
"The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
)
if doc_chunk_count == 0:
raise ValueError(
@@ -869,8 +807,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
chunk IDs vs querying for matching document chunks.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
)
results: list[InferenceChunk] = []
for chunk_request in chunk_requests:
@@ -917,8 +854,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
# TODO(andrei): This could be better, the caller should just make this
# decision when passing in the query param. See the above comment in the
@@ -938,10 +874,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
include_hidden=False,
)
# NOTE: Using z-score normalization here because it's better for hybrid
# search from a theoretical standpoint. Empirically on a small dataset
# of up to 10K docs, it's not very different. Likely more impactful at
# scale.
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
@@ -968,8 +902,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
dirty: bool | None = None, # noqa: ARG002
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_random_search_query(
tenant_state=self._tenant_state,
@@ -999,8 +932,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
complete.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
f"{self._index_name}."
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
)
# Do not raise if the document already exists, just update. This is
# because the document may already have been indexed during the

View File

@@ -243,8 +243,7 @@ class DocumentChunk(BaseModel):
return value
if not isinstance(value, int):
raise ValueError(
f"Bug: Expected an int for the last_updated property from OpenSearch, got "
f"{type(value)} instead."
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
)
return datetime.fromtimestamp(value, tz=timezone.utc)
@@ -285,22 +284,19 @@ class DocumentChunk(BaseModel):
elif isinstance(value, TenantState):
if MULTI_TENANT != value.multitenant:
raise ValueError(
f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
"current global tenancy state."
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
f"({value.multitenant}) does not match the program's current global tenancy state."
)
return value
elif not isinstance(value, str):
raise ValueError(
f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
f"{type(value)} instead."
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
)
else:
if not MULTI_TENANT:
raise ValueError(
"Bug: Got a non-null str for the tenant_id property from OpenSearch but "
"multi-tenant mode is not enabled. This is unexpected because in single-tenant "
"mode we don't expect to see a tenant_id."
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
)
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
@@ -356,10 +352,8 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search
# time for variant matching. Configure via
# OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
# after a change.
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for

View File

@@ -698,6 +698,41 @@ class DocumentQuery:
"""
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
def _get_assistant_knowledge_filter(
attached_doc_ids: list[str] | None,
node_ids: list[int] | None,
file_ids: list[UUID] | None,
document_sets: list[str] | None,
) -> dict[str, Any]:
"""Combined filter for assistant knowledge.
When an assistant has attached knowledge, search should be scoped to:
- Documents explicitly attached (by document ID), OR
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
- User-uploaded files attached to the assistant, OR
- Documents in the assistant's document sets (if any)
"""
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
if attached_doc_ids:
knowledge_filter["bool"]["should"].append(
_get_attached_document_id_filter(attached_doc_ids)
)
if node_ids:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(node_ids)
)
if file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
return knowledge_filter
filter_clauses: list[dict[str, Any]] = []
if not include_hidden:
@@ -723,53 +758,41 @@ class DocumentQuery:
# document's metadata list.
filter_clauses.append(_get_tag_filter(tags))
# Knowledge scope: explicit knowledge attachments restrict what
# an assistant can see. When none are set the assistant
# searches everything.
#
# project_id / persona_id are additive: they make overflowing
# user files findable but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
has_knowledge_scope = (
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
has_assistant_knowledge = (
attached_document_ids
or hierarchy_node_ids
or user_file_ids
or document_sets
)
if has_knowledge_scope:
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
if attached_document_ids:
knowledge_filter["bool"]["should"].append(
_get_attached_document_id_filter(attached_document_ids)
if has_assistant_knowledge:
# If assistant has attached knowledge, scope search to that knowledge.
# Document sets are included in the OR filter so directly attached
# docs are always findable even if not in the document sets.
filter_clauses.append(
_get_assistant_knowledge_filter(
attached_document_ids,
hierarchy_node_ids,
user_file_ids,
document_sets,
)
if hierarchy_node_ids:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(hierarchy_node_ids)
)
if user_file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(user_file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
# Additive: widen scope to also cover overflowing user
# files, but only when an explicit restriction is already
# in effect.
if project_id is not None:
knowledge_filter["bool"]["should"].append(
_get_user_project_filter(project_id)
)
if persona_id is not None:
knowledge_filter["bool"]["should"].append(
_get_persona_filter(persona_id)
)
filter_clauses.append(knowledge_filter)
)
elif user_file_ids:
# Fallback for non-assistant user file searches (e.g., project searches)
# If at least one user file ID is provided, the caller will only
# retrieve documents where the document ID is in this input list of
# file IDs.
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
if project_id is not None:
# If a project ID is provided, the caller will only retrieve
# documents where the project ID provided here is present in the
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if persona_id is not None:
filter_clauses.append(_get_persona_filter(persona_id))
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve

View File

@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
persona_ids=persona_ids,
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
vespa_document_index.update([update_request])
vespa_document_index.update([update_request])
def delete_single(
self,
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
tenant_state = TenantState(
tenant_id=get_current_tenant_id(),
multitenant=MULTI_TENANT,
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
raise ValueError(
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
total_chunks_deleted = 0
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
total_chunks_deleted += vespa_document_index.delete(
document_id=doc_id, chunk_count=chunk_count
)
return total_chunks_deleted
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
def id_based_retrieval(
self,

View File

@@ -23,8 +23,11 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def build_tenant_id_filter(tenant_id: str) -> str:
return f'({TENANT_ID} contains "{tenant_id}")'
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
if include_trailing_and:
filter_str += " and "
return filter_str
def build_vespa_filters(
@@ -34,22 +37,30 @@ def build_vespa_filters(
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
if not key or not vals:
return ""
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
if not eq_elems:
return ""
return f"({' or '.join(eq_elems)})"
or_clause = " or ".join(eq_elems)
return f"({or_clause}) and "
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
"""For an integer field filter.
Returns a bare clause or ""."""
"""
For an integer field filter.
If vals is not None, we want *only* docs whose key matches one of vals.
"""
# If `vals` is None => skip the filter entirely
if vals is None or not vals:
return ""
# Otherwise build the OR filter
eq_elems = [f"{key} = {val}" for val in vals]
return f"({' or '.join(eq_elems)})"
or_clause = " or ".join(eq_elems)
result = f"({or_clause}) and "
return result
def _build_kg_filter(
kg_entities: list[str] | None,
@@ -62,12 +73,16 @@ def build_vespa_filters(
combined_filter_parts = []
def _build_kge(entity: str) -> str:
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
# TYPE::* -> ({prefix: true}"TYPE")
GENERAL = "::*"
if entity.endswith(GENERAL):
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
else:
return f'"{entity}"'
# OR the entities (give new design)
if kg_entities:
filter_parts = []
for kg_entity in kg_entities:
@@ -89,7 +104,8 @@ def build_vespa_filters(
# TODO: remove kg terms entirely from prompts and codebase
return f"({' and '.join(combined_filter_parts)})"
# AND the combined filter parts
return f"({' and '.join(combined_filter_parts)}) and "
def _build_kg_source_filters(
kg_sources: list[str] | None,
@@ -98,14 +114,16 @@ def build_vespa_filters(
return ""
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
return f"({' or '.join(source_phrases)})"
return f"({' or '.join(source_phrases)}) and "
def _build_kg_chunk_id_zero_only_filter(
kg_chunk_id_zero_only: bool,
) -> str:
if not kg_chunk_id_zero_only:
return ""
return "(chunk_id = 0)"
return "(chunk_id = 0 ) and "
def _build_time_filter(
cutoff: datetime | None,
@@ -117,8 +135,8 @@ def build_vespa_filters(
cutoff_secs = int(cutoff.timestamp())
if include_untimed:
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
def _build_user_project_filter(
project_id: int | None,
@@ -129,7 +147,8 @@ def build_vespa_filters(
pid = int(project_id)
except Exception:
return ""
return f'({USER_PROJECT} contains "{pid}")'
# Vespa YQL 'contains' expects a string literal; quote the integer
return f'({USER_PROJECT} contains "{pid}") and '
def _build_persona_filter(
persona_id: int | None,
@@ -141,94 +160,73 @@ def build_vespa_filters(
except Exception:
logger.warning(f"Invalid persona ID: {persona_id}")
return ""
return f'({PERSONAS} contains "{pid}")'
return f'({PERSONAS} contains "{pid}") and '
def _append(parts: list[str], clause: str) -> None:
if clause:
parts.append(clause)
# Collect all top-level filter clauses, then join with " and " at the end.
filter_parts: list[str] = []
if not include_hidden:
filter_parts.append(f"!({HIDDEN}=true)")
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
# If running in multi-tenant mode
if filters.tenant_id and MULTI_TENANT:
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
filter_str += build_tenant_id_filter(
filters.tenant_id, include_trailing_and=True
)
# ACL filters
if filters.access_control_list is not None:
_append(
filter_parts,
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
)
# Source type filters
source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None
)
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
# Tag filters
tag_attributes = None
if filters.tags:
# build e.g. "tag_key|tag_value"
tag_attributes = [
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
]
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
# Knowledge scope: explicit knowledge attachments (document_sets,
# user_file_ids) restrict what an assistant can see. When none are
# set, the assistant can see everything.
#
# project_id / persona_id are additive: they make overflowing user
# files findable in Vespa but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
knowledge_scope_parts: list[str] = []
_append(
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
)
# Document sets
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
# Convert UUIDs to strings for user_file_ids
user_file_ids_str = (
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
)
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
# Only include project/persona scopes when an explicit knowledge
# restriction is already in effect — they widen the scope to also
# cover overflowing user files but never restrict on their own.
if knowledge_scope_parts:
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
# User project filter (array<int> attribute membership)
filter_str += _build_user_project_filter(filters.project_id)
if len(knowledge_scope_parts) > 1:
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
elif len(knowledge_scope_parts) == 1:
filter_parts.append(knowledge_scope_parts[0])
# Persona filter (array<int> attribute membership)
filter_str += _build_persona_filter(filters.persona_id)
# Time filter
_append(filter_parts, _build_time_filter(filters.time_cutoff))
filter_str += _build_time_filter(filters.time_cutoff)
# # Knowledge Graph Filters
# _append(filter_parts, _build_kg_filter(
# filter_str += _build_kg_filter(
# kg_entities=filters.kg_entities,
# kg_relationships=filters.kg_relationships,
# kg_terms=filters.kg_terms,
# ))
# )
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
# filter_str += _build_kg_source_filters(filters.kg_sources)
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
# filter_str += _build_kg_chunk_id_zero_only_filter(
# filters.kg_chunk_id_zero_only or False
# ))
# )
filter_str = " and ".join(filter_parts)
if filter_str and not remove_trailing_and:
filter_str += " and "
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]
return filter_str

View File

@@ -91,11 +91,11 @@ class OnyxErrorCode(Enum):
"""Build a structured error detail dict.
Returns a dict like:
{"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
If no message is supplied, the error code itself is used as the detail.
If no message is supplied, the error code itself is used as the message.
"""
return {
"error_code": self.code,
"detail": message or self.code,
"message": message or self.code,
}

View File

@@ -3,7 +3,7 @@
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
converts it into a JSON response with the standard
``{"error_code": "...", "detail": "..."}`` shape.
``{"error_code": "...", "message": "..."}`` shape.
Usage::
@@ -37,22 +37,21 @@ class OnyxError(Exception):
Attributes:
error_code: The ``OnyxErrorCode`` enum member.
detail: Human-readable detail (defaults to the error code string).
message: Human-readable message (defaults to the error code string).
status_code: HTTP status — either overridden or from the error code.
"""
def __init__(
self,
error_code: OnyxErrorCode,
detail: str | None = None,
message: str | None = None,
*,
status_code_override: int | None = None,
) -> None:
resolved_detail = detail or error_code.code
super().__init__(resolved_detail)
self.error_code = error_code
self.detail = resolved_detail
self.message = message or error_code.code
self._status_code_override = status_code_override
super().__init__(self.message)
@property
def status_code(self) -> int:
@@ -73,11 +72,11 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
) -> JSONResponse:
status_code = exc.status_code
if status_code >= 500:
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
elif status_code >= 400:
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
return JSONResponse(
status_code=status_code,
content=exc.error_code.detail(exc.detail),
content=exc.error_code.detail(exc.message),
)

View File

@@ -19,16 +19,12 @@ class OnyxMimeTypes:
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
"text/x-markdown",
"text/x-log",
"text/x-config",
"text/tab-separated-values",
"application/json",
"application/xml",
"text/xml",
"application/x-yaml",
"application/yaml",
"text/yaml",
"text/x-yaml",
}
DOCUMENT_MIME_TYPES = {
PDF_MIME_TYPE,

View File

@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -64,7 +65,6 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time

View File

@@ -1,49 +1,30 @@
import re
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
def sanitize_string(value: str) -> str:
"""Strip characters that PostgreSQL text/JSONB columns cannot store.
Removes:
- NUL bytes (\\x00)
- UTF-16 surrogates (\\ud800-\\udfff), which are invalid in UTF-8
"""
sanitized = value.replace("\x00", "")
sanitized = _SURROGATE_RE.sub("", sanitized)
if value and not sanitized:
logger.warning(
"sanitize_string: all characters were removed from a non-empty string"
)
return sanitized
def _sanitize_string(value: str) -> str:
return value.replace("\x00", "")
def sanitize_json_like(value: Any) -> Any:
"""Recursively sanitize all strings in a JSON-like structure (dict/list/tuple)."""
def _sanitize_json_like(value: Any) -> Any:
if isinstance(value, str):
return sanitize_string(value)
return _sanitize_string(value)
if isinstance(value, list):
return [sanitize_json_like(item) for item in value]
return [_sanitize_json_like(item) for item in value]
if isinstance(value, tuple):
return tuple(sanitize_json_like(item) for item in value)
return tuple(_sanitize_json_like(item) for item in value)
if isinstance(value, dict):
sanitized: dict[Any, Any] = {}
for key, nested_value in value.items():
cleaned_key = sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = sanitize_json_like(nested_value)
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
return sanitized
return value
@@ -53,27 +34,27 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
return expert.model_copy(
update={
"display_name": (
sanitize_string(expert.display_name)
_sanitize_string(expert.display_name)
if expert.display_name is not None
else None
),
"first_name": (
sanitize_string(expert.first_name)
_sanitize_string(expert.first_name)
if expert.first_name is not None
else None
),
"middle_initial": (
sanitize_string(expert.middle_initial)
_sanitize_string(expert.middle_initial)
if expert.middle_initial is not None
else None
),
"last_name": (
sanitize_string(expert.last_name)
_sanitize_string(expert.last_name)
if expert.last_name is not None
else None
),
"email": (
sanitize_string(expert.email) if expert.email is not None else None
_sanitize_string(expert.email) if expert.email is not None else None
),
}
)
@@ -82,10 +63,10 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
return ExternalAccess(
external_user_emails={
sanitize_string(email) for email in external_access.external_user_emails
_sanitize_string(email) for email in external_access.external_user_emails
},
external_user_group_ids={
sanitize_string(group_id)
_sanitize_string(group_id)
for group_id in external_access.external_user_group_ids
},
is_public=external_access.is_public,
@@ -95,26 +76,26 @@ def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess
def sanitize_document_for_postgres(document: Document) -> Document:
cleaned_doc = document.model_copy(deep=True)
cleaned_doc.id = sanitize_string(cleaned_doc.id)
cleaned_doc.semantic_identifier = sanitize_string(cleaned_doc.semantic_identifier)
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
if cleaned_doc.title is not None:
cleaned_doc.title = sanitize_string(cleaned_doc.title)
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
cleaned_doc.parent_hierarchy_raw_node_id = sanitize_string(
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
cleaned_doc.parent_hierarchy_raw_node_id
)
cleaned_doc.metadata = {
sanitize_string(key): (
[sanitize_string(item) for item in value]
_sanitize_string(key): (
[_sanitize_string(item) for item in value]
if isinstance(value, list)
else sanitize_string(value)
else _sanitize_string(value)
)
for key, value in cleaned_doc.metadata.items()
}
if cleaned_doc.doc_metadata is not None:
cleaned_doc.doc_metadata = sanitize_json_like(cleaned_doc.doc_metadata)
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
if cleaned_doc.primary_owners is not None:
cleaned_doc.primary_owners = [
@@ -132,11 +113,11 @@ def sanitize_document_for_postgres(document: Document) -> Document:
for section in cleaned_doc.sections:
if section.link is not None:
section.link = sanitize_string(section.link)
section.link = _sanitize_string(section.link)
if section.text is not None:
section.text = sanitize_string(section.text)
section.text = _sanitize_string(section.text)
if section.image_file_id is not None:
section.image_file_id = sanitize_string(section.image_file_id)
section.image_file_id = _sanitize_string(section.image_file_id)
return cleaned_doc
@@ -148,12 +129,12 @@ def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
cleaned_node = node.model_copy(deep=True)
cleaned_node.raw_node_id = sanitize_string(cleaned_node.raw_node_id)
cleaned_node.display_name = sanitize_string(cleaned_node.display_name)
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
if cleaned_node.raw_parent_id is not None:
cleaned_node.raw_parent_id = sanitize_string(cleaned_node.raw_parent_id)
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
if cleaned_node.link is not None:
cleaned_node.link = sanitize_string(cleaned_node.link)
cleaned_node.link = _sanitize_string(cleaned_node.link)
if cleaned_node.external_access is not None:
cleaned_node.external_access = _sanitize_external_access(

View File

@@ -22,7 +22,6 @@ class LlmProviderNames(str, Enum):
OPENROUTER = "openrouter"
AZURE = "azure"
OLLAMA_CHAT = "ollama_chat"
LM_STUDIO = "lm_studio"
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
@@ -42,8 +41,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.OPENROUTER,
LlmProviderNames.AZURE,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
]
@@ -59,8 +56,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.AZURE: "Azure",
"ollama": "Ollama",
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -108,10 +103,8 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.BEDROCK_CONVERSE,
LlmProviderNames.OPENROUTER,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
}
# Model family name mappings for display name generation

View File

@@ -20,9 +20,7 @@ from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.constants import (
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
)
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
@@ -34,18 +32,14 @@ logger = setup_logger()
def _build_provider_extra_headers(
provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
api_key = raw.strip() if raw else None
if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
api_key = raw_api_key.strip() if raw_api_key else None
if not api_key:
return {}
return {
"Authorization": (
api_key
if api_key.lower().startswith("bearer ")
else f"Bearer {api_key}"
)
}
if not api_key.lower().startswith("bearer "):
api_key = f"Bearer {api_key}"
return {"Authorization": api_key}
# Passing these will put Onyx on the OpenRouter leaderboard
elif provider == LlmProviderNames.OPENROUTER:

View File

@@ -1512,10 +1512,6 @@
"display_name": "Claude Opus 4.5",
"model_vendor": "anthropic"
},
"claude-opus-4-6": {
"display_name": "Claude Opus 4.6",
"model_vendor": "anthropic"
},
"claude-opus-4-5-20251101": {
"display_name": "Claude Opus 4.5",
"model_vendor": "anthropic",
@@ -1530,10 +1526,6 @@
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic"
},
"claude-sonnet-4-6": {
"display_name": "Claude Sonnet 4.6",
"model_vendor": "anthropic"
},
"claude-sonnet-4-5-20250929": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -2524,10 +2516,6 @@
"model_vendor": "openai",
"model_version": "2025-10-06"
},
"gpt-5.4": {
"display_name": "GPT-5.4",
"model_vendor": "openai"
},
"gpt-5.2-pro-2025-12-11": {
"display_name": "GPT-5.2 Pro",
"model_vendor": "openai",

View File

@@ -42,7 +42,6 @@ from onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG
from onyx.llm.well_known_providers.constants import (
AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT,
)
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.well_known_providers.constants import (
@@ -93,98 +92,6 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
return [prompt.model_dump(exclude_none=True)]
def _normalize_content(raw: Any) -> str:
"""Normalize a message content field to a plain string.
Content can be a string, None, or a list of content-block dicts
(e.g. [{"type": "text", "text": "..."}]).
"""
if raw is None:
return ""
if isinstance(raw, str):
return raw
if isinstance(raw, list):
return "\n".join(
block.get("text", "") if isinstance(block, dict) else str(block)
for block in raw
)
return str(raw)
def _strip_tool_content_from_messages(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Convert tool-related messages to plain text.
Bedrock's Converse API requires toolConfig when messages contain
toolUse/toolResult content blocks. When no tools are provided for the
current request, we must convert any tool-related history into plain text
to avoid the "toolConfig field must be defined" error.
This is the same approach used by _OllamaHistoryMessageFormatter.
"""
result: list[dict[str, Any]] = []
for msg in messages:
role = msg.get("role")
tool_calls = msg.get("tool_calls")
if role == "assistant" and tool_calls:
# Convert structured tool calls to text representation
tool_call_lines = []
for tc in tool_calls:
func = tc.get("function", {})
name = func.get("name", "unknown")
args = func.get("arguments", "{}")
tc_id = tc.get("id", "")
tool_call_lines.append(
f"[Tool Call] name={name} id={tc_id} args={args}"
)
existing_content = _normalize_content(msg.get("content"))
parts = (
[existing_content] + tool_call_lines
if existing_content
else tool_call_lines
)
new_msg = {
"role": "assistant",
"content": "\n".join(parts),
}
result.append(new_msg)
elif role == "tool":
# Convert tool response to user message with text content
tool_call_id = msg.get("tool_call_id", "")
content = _normalize_content(msg.get("content"))
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
# Merge into previous user message if it is also a converted
# tool result to avoid consecutive user messages (Bedrock requires
# strict user/assistant alternation).
if (
result
and result[-1]["role"] == "user"
and "[Tool Result]" in result[-1].get("content", "")
):
result[-1]["content"] += "\n\n" + tool_result_text
else:
result.append({"role": "user", "content": tool_result_text})
else:
result.append(msg)
return result
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
if msg.get("role") == "tool":
return True
if msg.get("role") == "assistant" and msg.get("tool_calls"):
return True
return False
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -250,9 +157,6 @@ class LitellmLLM(LLM):
elif model_provider == LlmProviderNames.OLLAMA_CHAT:
if k == OLLAMA_API_KEY_CONFIG_KEY:
model_kwargs["api_key"] = v
elif model_provider == LlmProviderNames.LM_STUDIO:
if k == LM_STUDIO_API_KEY_CONFIG_KEY:
model_kwargs["api_key"] = v
elif model_provider == LlmProviderNames.BEDROCK:
if k == AWS_REGION_NAME_KWARG:
model_kwargs[k] = v
@@ -269,19 +173,6 @@ class LitellmLLM(LLM):
elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT:
model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v
# LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided,
# which LM Studio rejects. Ensure we always pass an explicit key (or empty
# string) to prevent LiteLLM from injecting its fake default.
if model_provider == LlmProviderNames.LM_STUDIO:
model_kwargs.setdefault("api_key", "")
# Users provide the server root (e.g. http://localhost:1234) but LiteLLM
# needs /v1 for OpenAI-compatible calls.
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
model_kwargs["api_base"] = self._api_base
# Default vertex_location to "global" if not provided for Vertex AI
# Latest gemini models are only available through the global region
if (
@@ -513,30 +404,13 @@ class LitellmLLM(LLM):
else nullcontext()
)
with env_ctx:
messages = _prompt_to_dicts(prompt)
# Bedrock's Converse API requires toolConfig when messages
# contain toolUse/toolResult content blocks. When no tools are
# provided for this request but the history contains tool
# content from previous turns, strip it to plain text.
is_bedrock = self._model_provider in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}
if (
is_bedrock
and not tools
and _messages_contain_tool_content(messages)
):
messages = _strip_tool_content_from_messages(messages)
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,

View File

@@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None:
error_msg = None
for _ in range(2):
try:
llm.invoke(UserMessage(content="Do not respond"), max_tokens=50)
llm.invoke(UserMessage(content="Do not respond"))
return None
except Exception as e:
error_msg = str(e)

View File

@@ -1,29 +1,52 @@
from onyx.llm.constants import LlmProviderNames
OPENAI_PROVIDER_NAME = "openai"
# Curated list of OpenAI models to show by default in the UI
OPENAI_VISIBLE_MODEL_NAMES = {
"gpt-5",
"gpt-5-mini",
"o1",
"o3-mini",
"gpt-4o",
"gpt-4o-mini",
}
BEDROCK_PROVIDER_NAME = "bedrock"
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
def _fallback_bedrock_regions() -> list[str]:
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
return [
"us-east-1",
"us-east-2",
"us-gov-east-1",
"us-gov-west-1",
"us-west-2",
"ap-northeast-1",
"ap-south-1",
"ap-southeast-1",
"ap-southeast-2",
"ap-east-1",
"ca-central-1",
"eu-central-1",
"eu-west-2",
]
OLLAMA_PROVIDER_NAME = "ollama_chat"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
LM_STUDIO_PROVIDER_NAME = "lm_studio"
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY,
}
# OpenRouter
OPENROUTER_PROVIDER_NAME = "openrouter"
ANTHROPIC_PROVIDER_NAME = "anthropic"
# Curated list of Anthropic models to show by default in the UI
ANTHROPIC_VISIBLE_MODEL_NAMES = {
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-haiku-4-5",
}
AZURE_PROVIDER_NAME = "azure"
@@ -31,6 +54,13 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
VERTEX_LOCATION_KWARG = "vertex_location"
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
# Curated list of Vertex AI models to show by default in the UI
VERTEXAI_VISIBLE_MODEL_NAMES = {
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-2.5-pro",
}
AWS_REGION_NAME_KWARG = "aws_region_name"
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"

View File

@@ -15,8 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
@@ -46,9 +44,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
}
@@ -327,13 +323,11 @@ def get_provider_display_name(provider_name: str) -> str:
_ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = {
OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)",
OLLAMA_PROVIDER_NAME: "Ollama",
LM_STUDIO_PROVIDER_NAME: "LM Studio",
ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)",
AZURE_PROVIDER_NAME: "Azure OpenAI",
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -1,12 +1,12 @@
{
"version": "1.1",
"updated_at": "2026-03-05T00:00:00Z",
"updated_at": "2026-02-05T00:00:00Z",
"providers": {
"openai": {
"default_model": { "name": "gpt-5.4" },
"default_model": { "name": "gpt-5.2" },
"additional_visible_models": [
{ "name": "gpt-5.4" },
{ "name": "gpt-5.2" }
{ "name": "gpt-5-mini" },
{ "name": "gpt-4.1" }
]
},
"anthropic": {
@@ -16,10 +16,6 @@
"name": "claude-opus-4-6",
"display_name": "Claude Opus 4.6"
},
{
"name": "claude-sonnet-4-6",
"display_name": "Claude Sonnet 4.6"
},
{
"name": "claude-opus-4-5",
"display_name": "Claude Opus 4.5"

View File

@@ -10,7 +10,6 @@ from onyx.mcp_server.utils import get_indexed_sources
from onyx.mcp_server.utils import require_access_token
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
@@ -27,14 +26,6 @@ async def search_indexed_documents(
Use this tool for information that is not public knowledge and specific to the user,
their team, their work, or their organization/company.
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
on every call, consuming tokens and adding latency.
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
but this should still be sufficient for most use cases. CE mode functionality should be swapped
when a dedicated CE search endpoint is implemented.
In EE mode, the dedicated search endpoint is used instead.
To find a list of available sources, use the `indexed_sources` resource.
Returns chunks of text as search results with snippets, scores, and metadata.
@@ -120,72 +111,47 @@ async def search_indexed_documents(
if time_cutoff_dt:
filters["time_cutoff"] = time_cutoff_dt.isoformat()
is_ee = global_version.is_ee_version()
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
search_request: dict[str, Any]
if is_ee:
# EE: use the dedicated search endpoint (no LLM invocation)
search_request = {
"search_query": query,
"filters": filters,
"num_docs_fed_to_llm_selection": limit,
"run_query_expansion": False,
"include_content": True,
"stream": False,
}
endpoint = f"{base_url}/search/send-search-message"
error_key = "error"
docs_key = "search_docs"
content_field = "content"
else:
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
search_request = {
"message": query,
"stream": False,
"chat_session_info": {},
}
if filters:
search_request["internal_search_filters"] = filters
endpoint = f"{base_url}/chat/send-chat-message"
error_key = "error_msg"
docs_key = "top_documents"
content_field = "blurb"
# Build the search request using the new SendSearchQueryRequest format
search_request = {
"search_query": query,
"filters": filters,
"num_docs_fed_to_llm_selection": limit,
"run_query_expansion": False,
"include_content": True,
"stream": False,
}
# Call the API server using the new send-search-message route
try:
response = await get_http_client().post(
endpoint,
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
json=search_request,
headers=auth_headers,
headers={"Authorization": f"Bearer {access_token.token}"},
)
response.raise_for_status()
result = response.json()
# Check for error in response
if result.get(error_key):
if result.get("error"):
return {
"documents": [],
"total_results": 0,
"query": query,
"error": result.get(error_key),
"error": result.get("error"),
}
documents = [
{
"semantic_identifier": doc.get("semantic_identifier"),
"content": doc.get(content_field),
"source_type": doc.get("source_type"),
"link": doc.get("link"),
"score": doc.get("score"),
}
for doc in result.get(docs_key, [])
# Return simplified format for MCP clients
fields_to_return = [
"semantic_identifier",
"content",
"source_type",
"link",
"score",
]
documents = [
{key: doc.get(key) for key in fields_to_return}
for doc in result.get("search_docs", [])
]
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
# `limit` only caps the returned list; fewer results may be returned if the
# backend retrieves fewer documents than requested.
documents = documents[:limit]
logger.info(
f"Onyx MCP Server: Internal search returned {len(documents)} results"
@@ -194,6 +160,7 @@ async def search_indexed_documents(
"documents": documents,
"total_results": len(documents),
"query": query,
"executed_queries": result.get("all_executed_queries", [query]),
}
except Exception as e:
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)

View File

@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
@@ -146,11 +146,6 @@ class SlackRenderer(HTMLRenderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
super().__init__()
self._table_headers: list[str] = []
self._current_row_cells: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
@@ -223,48 +218,5 @@ class SlackRenderer(HTMLRenderer):
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
) -> str:
if head:
self._table_headers.append(text.strip())
else:
self._current_row_cells.append(text.strip())
return ""
def table_head(self, text: str) -> str: # noqa: ARG002
self._current_row_cells = []
return ""
def table_row(self, text: str) -> str: # noqa: ARG002
cells = self._current_row_cells
self._current_row_cells = []
# First column becomes the bold title, remaining columns are bulleted fields
lines: list[str] = []
if cells:
title = cells[0]
if title:
# Avoid double-wrapping if cell already contains bold markup
if title.startswith("*") and title.endswith("*") and len(title) > 1:
lines.append(title)
else:
lines.append(f"*{title}*")
for i, cell in enumerate(cells[1:], start=1):
if i < len(self._table_headers):
lines.append(f"{self._table_headers[i]}: {cell}")
else:
lines.append(f"{cell}")
return "\n".join(lines) + "\n\n"
def table_body(self, text: str) -> str:
return text
def table(self, text: str) -> str:
self._table_headers = []
self._current_row_cells = []
return text + "\n"
def paragraph(self, text: str) -> str:
return f"{text}\n\n"

View File

@@ -961,9 +961,9 @@
"license": "MIT"
},
"node_modules/@hono/node-server": {
"version": "1.19.10",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
"version": "1.19.9",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz",
"integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==",
"license": "MIT",
"engines": {
"node": ">=18.14.1"
@@ -1573,6 +1573,27 @@
}
}
},
"node_modules/@isaacs/balanced-match": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
"integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==",
"license": "MIT",
"engines": {
"node": "20 || >=22"
}
},
"node_modules/@isaacs/brace-expansion": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
"license": "MIT",
"dependencies": {
"@isaacs/balanced-match": "^4.0.1"
},
"engines": {
"node": "20 || >=22"
}
},
"node_modules/@jridgewell/gen-mapping": {
"version": "0.3.13",
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz",
@@ -1659,9 +1680,9 @@
}
},
"node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
"version": "8.18.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
"version": "8.17.1",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
"license": "MIT",
"dependencies": {
"fast-deep-equal": "^3.1.3",
@@ -3834,27 +3855,6 @@
"path-browserify": "^1.0.1"
}
},
"node_modules/@ts-morph/common/node_modules/balanced-match": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
"integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==",
"license": "MIT",
"engines": {
"node": "18 || 20 || >=22"
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
},
"engines": {
"node": "18 || 20 || >=22"
}
},
"node_modules/@ts-morph/common/node_modules/fast-glob": {
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz",
@@ -3884,15 +3884,15 @@
}
},
"node_modules/@ts-morph/common/node_modules/minimatch": {
"version": "10.2.4",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz",
"integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==",
"version": "10.1.1",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz",
"integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==",
"license": "BlueOak-1.0.0",
"dependencies": {
"brace-expansion": "^5.0.2"
"@isaacs/brace-expansion": "^5.0.0"
},
"engines": {
"node": "18 || 20 || >=22"
"node": "20 || >=22"
},
"funding": {
"url": "https://github.com/sponsors/isaacs"
@@ -4234,13 +4234,13 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
"version": "9.0.9",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
"version": "9.0.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
"dev": true,
"license": "ISC",
"dependencies": {
"brace-expansion": "^2.0.2"
"brace-expansion": "^2.0.1"
},
"engines": {
"node": ">=16 || 14 >=14.17"
@@ -4619,9 +4619,9 @@
}
},
"node_modules/ajv": {
"version": "6.14.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
"version": "6.12.6",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -4653,9 +4653,9 @@
}
},
"node_modules/ajv-formats/node_modules/ajv": {
"version": "8.18.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
"version": "8.17.1",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
"license": "MIT",
"dependencies": {
"fast-deep-equal": "^3.1.3",
@@ -6758,12 +6758,12 @@
}
},
"node_modules/express-rate-limit": {
"version": "8.3.0",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
"version": "8.2.1",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
"license": "MIT",
"dependencies": {
"ip-address": "10.1.0"
"ip-address": "10.0.1"
},
"engines": {
"node": ">= 16"
@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"version": "4.11.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"
@@ -7556,9 +7556,9 @@
}
},
"node_modules/ip-address": {
"version": "10.1.0",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
"version": "10.0.1",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
"license": "MIT",
"engines": {
"node": ">= 12"
@@ -8831,9 +8831,9 @@
}
},
"node_modules/minimatch": {
"version": "3.1.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"dev": true,
"license": "ISC",
"dependencies": {
@@ -9699,9 +9699,9 @@
}
},
"node_modules/qs": {
"version": "6.14.2",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
"version": "6.14.1",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
"license": "BSD-3-Clause",
"dependencies": {
"side-channel": "^1.1.0"

View File

@@ -11,7 +11,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import delete_document_set as db_delete_document_set
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import insert_document_set
@@ -143,10 +142,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
if DISABLE_VECTOR_DB:
db_session.refresh(document_set)
db_delete_document_set(document_set, db_session)
else:
if not DISABLE_VECTOR_DB:
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},

View File

@@ -1,6 +1,5 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -15,6 +14,8 @@ from onyx.db.llm import remove_llm_provider__no_commit
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import ModelConfiguration
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.factory import validate_credentials
@@ -74,9 +75,9 @@ def _build_llm_provider_request(
# Clone mode: Only use API key from source provider
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
if not source_provider:
raise HTTPException(
status_code=404,
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Source LLM provider with id {source_llm_provider_id} not found",
)
_validate_llm_provider_change(
@@ -110,9 +111,9 @@ def _build_llm_provider_request(
)
if not provider:
raise HTTPException(
status_code=400,
detail="No provider or source llm provided",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No provider or source llm provided",
)
credentials = ImageGenerationProviderCredentials(
@@ -124,9 +125,9 @@ def _build_llm_provider_request(
)
if not validate_credentials(provider, credentials):
raise HTTPException(
status_code=400,
detail=f"Incorrect credentials for {provider}",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Incorrect credentials for {provider}",
)
return LLMProviderUpsertRequest(
@@ -215,9 +216,9 @@ def test_image_generation(
LLMProviderModel, test_request.source_llm_provider_id
)
if not source_provider:
raise HTTPException(
status_code=404,
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
)
_validate_llm_provider_change(
@@ -236,9 +237,9 @@ def test_image_generation(
provider = source_provider.provider
if provider is None:
raise HTTPException(
status_code=400,
detail="No provider or source llm provided",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No provider or source llm provided",
)
try:
@@ -257,14 +258,14 @@ def test_image_generation(
),
)
except ValueError:
raise HTTPException(
status_code=404,
detail=f"Invalid image generation provider: {provider}",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Invalid image generation provider: {provider}",
)
except ImageProviderCredentialsError:
raise HTTPException(
status_code=401,
detail="Invalid image generation credentials",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
"Invalid image generation credentials",
)
quality = _get_test_quality_for_model(test_request.model_name)
@@ -276,15 +277,15 @@ def test_image_generation(
n=1,
quality=quality,
)
except HTTPException:
except OnyxError:
raise
except Exception as e:
# Log only exception type to avoid exposing sensitive data
# (LiteLLM errors may contain URLs with API keys or auth tokens)
logger.warning(f"Image generation test failed: {type(e).__name__}")
raise HTTPException(
status_code=400,
detail=f"Image generation test failed: {type(e).__name__}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Image generation test failed: {type(e).__name__}",
)
@@ -309,9 +310,9 @@ def create_config(
db_session, config_create.image_provider_id
)
if existing_config:
raise HTTPException(
status_code=400,
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
)
try:
@@ -345,10 +346,10 @@ def create_config(
db_session.commit()
db_session.refresh(config)
return ImageGenerationConfigView.from_model(config)
except HTTPException:
except OnyxError:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
@admin_router.get("/config")
@@ -373,9 +374,9 @@ def get_config_credentials(
"""
config = get_image_generation_config(db_session, image_provider_id)
if not config:
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
return ImageGenerationCredentials.from_model(config)
@@ -401,9 +402,9 @@ def update_config(
# 1. Get existing config
existing_config = get_image_generation_config(db_session, image_provider_id)
if not existing_config:
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
@@ -472,10 +473,10 @@ def update_config(
db_session.refresh(existing_config)
return ImageGenerationConfigView.from_model(existing_config)
except HTTPException:
except OnyxError:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
@admin_router.delete("/config/{image_provider_id}")
@@ -489,9 +490,9 @@ def delete_config(
# Get the config first to find the associated LLM provider
existing_config = get_image_generation_config(db_session, image_provider_id)
if not existing_config:
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
llm_provider_id = existing_config.model_configuration.llm_provider_id
@@ -503,10 +504,10 @@ def delete_config(
remove_llm_provider__no_commit(db_session, llm_provider_id)
db_session.commit()
except HTTPException:
except OnyxError:
raise
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
@admin_router.post("/config/{image_provider_id}/default")
@@ -519,7 +520,7 @@ def set_config_as_default(
try:
set_default_image_generation_config(db_session, image_provider_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
@admin_router.delete("/config/{image_provider_id}/default")
@@ -532,4 +533,4 @@ def unset_config_as_default(
try:
unset_default_image_generation_config(db_session, image_provider_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))

View File

@@ -48,7 +48,6 @@ from onyx.llm.utils import test_llm
from onyx.llm.well_known_providers.auto_update_service import (
fetch_llm_recommendations_from_github,
)
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
from onyx.llm.well_known_providers.llm_provider_options import (
fetch_available_well_known_llms,
)
@@ -58,30 +57,22 @@ from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
from onyx.server.manage.llm.models import LMStudioModelsRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.server.manage.llm.utils import generate_bedrock_display_name
from onyx.server.manage.llm.utils import generate_ollama_display_name
from onyx.server.manage.llm.utils import infer_vision_support
from onyx.server.manage.llm.utils import is_reasoning_model
from onyx.server.manage.llm.utils import is_valid_bedrock_model
from onyx.server.manage.llm.utils import ModelMetadata
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
@@ -102,34 +93,6 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _sync_fetched_models(
db_session: Session,
provider_name: str,
models: list[SyncModelEntry],
source_label: str,
) -> None:
"""Sync fetched models to DB for the given provider.
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of SyncModelEntry objects describing the fetched models
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
"""
try:
new_count = sync_model_configurations(
db_session=db_session,
provider_name=provider_name,
models=models,
)
if new_count > 0:
logger.info(
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
# Keys in custom_config that contain sensitive credentials
_SENSITIVE_CONFIG_KEYS = {
"vertex_credentials",
@@ -478,18 +441,6 @@ def put_llm_provider(
not existing_provider or not existing_provider.is_auto_mode
)
# When transitioning to auto mode, preserve existing model configurations
# so the upsert doesn't try to delete them (which would trip the default
# model protection guard). sync_auto_mode_models will handle the model
# lifecycle afterward — adding new models, hiding removed ones, and
# updating the default. This is safe even if sync fails: the provider
# keeps its old models and default rather than losing them.
if transitioning_to_auto_mode and existing_provider:
llm_provider_upsert_request.model_configurations = [
ModelConfigurationUpsertRequest.from_model(mc)
for mc in existing_provider.model_configurations
]
try:
result = upsert_llm_provider(
llm_provider_upsert_request=llm_provider_upsert_request,
@@ -502,6 +453,7 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
)
@@ -995,20 +947,27 @@ def get_bedrock_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in results
],
source_label="Bedrock",
)
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
return results
@@ -1126,20 +1085,27 @@ def get_ollama_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
],
source_label="Ollama",
)
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
return sorted_results
@@ -1228,226 +1194,26 @@ def get_openrouter_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="OpenRouter",
)
return sorted_results
@admin_router.post("/lm-studio/available-models")
def get_lm_studio_available_models(
request: LMStudioModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LMStudioFinalModelResponse]:
"""Fetch available models from an LM Studio server.
Uses the LM Studio-native /api/v1/models endpoint which exposes
rich metadata including capabilities (vision, reasoning),
display names, and context lengths.
"""
cleaned_api_base = request.api_base.strip().rstrip("/")
# Strip /v1 suffix that users may copy from OpenAI-compatible tool configs;
# the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models.
cleaned_api_base = cleaned_api_base.removesuffix("/v1")
if not cleaned_api_base:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API base URL is required to fetch LM Studio models.",
)
# If provider_name is given and the api_key hasn't been changed by the user,
# fall back to the stored API key from the database (the form value is masked).
api_key = request.api_key
if request.provider_name and not request.api_key_changed:
existing_provider = fetch_existing_llm_provider(
name=request.provider_name, db_session=db_session
)
if existing_provider and existing_provider.custom_config:
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
url = f"{cleaned_api_base}/api/v1/models"
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
response_json = response.json()
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LM Studio models: {e}",
)
models = response_json.get("models", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your LM Studio server.",
)
results: list[LMStudioFinalModelResponse] = []
for item in models:
# Filter to LLM-type models only (skip embeddings, etc.)
if item.get("type") != "llm":
continue
model_key = item.get("key")
if not model_key:
continue
display_name = item.get("display_name") or model_key
max_context_length = item.get("max_context_length")
capabilities = item.get("capabilities") or {}
results.append(
LMStudioFinalModelResponse(
name=model_key,
display_name=display_name,
max_input_tokens=max_context_length,
supports_image_input=capabilities.get("vision", False),
supports_reasoning=capabilities.get("reasoning", False)
or is_reasoning_model(model_key, display_name),
)
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from LM Studio server.",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="LM Studio",
)
return sorted_results
@admin_router.post("/litellm/available-models")
def get_litellm_available_models(
request: LitellmModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
"""Fetch available models from Litellm proxy /v1/models endpoint."""
response_json = _get_litellm_models_response(
api_key=request.api_key, api_base=request.api_base
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Litellm endpoint",
)
results: list[LitellmFinalModelResponse] = []
for model in models:
try:
model_details = LitellmModelDetails.model_validate(model)
results.append(
LitellmFinalModelResponse(
provider_name=model_details.owned_by,
model_name=model_details.id,
)
)
except Exception as e:
logger.warning(
"Failed to parse Litellm model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Litellm",
)
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.model_name,
display_name=r.model_name,
)
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
],
source_label="LiteLLM",
)
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
return sorted_results
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"LiteLLM models endpoint not found at {url}. "
"Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)

View File

@@ -371,22 +371,6 @@ class OpenRouterFinalModelResponse(BaseModel):
supports_image_input: bool
# LM Studio dynamic models fetch
class LMStudioModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
api_key_changed: bool = False
provider_name: str | None = None # Optional: to save models to existing provider
class LMStudioFinalModelResponse(BaseModel):
name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B")
display_name: str # Human-readable name
max_input_tokens: int | None # From LM Studio API or None if unavailable
supports_image_input: bool
supports_reasoning: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@@ -420,32 +404,3 @@ class LLMProviderResponse(BaseModel, Generic[T]):
default_text=default_text,
default_vision=default_vision,
)
class SyncModelEntry(BaseModel):
"""Typed model for syncing fetched models to the DB."""
name: str
display_name: str
max_input_tokens: int | None = None
supports_image_input: bool = False
class LitellmModelsRequest(BaseModel):
api_key: str
api_base: str
provider_name: str | None = None # Optional: to save models to existing provider
class LitellmModelDetails(BaseModel):
"""Response model for Litellm proxy /api/v1/models endpoint"""
id: str # Model ID (e.g. "gpt-4o")
object: str # "model"
created: int # Unix timestamp in seconds
owned_by: str # Provider name (e.g. "openai")
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")

View File

@@ -12,7 +12,6 @@ from typing import TypedDict
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -24,7 +23,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.OPENROUTER,
LlmProviderNames.BEDROCK,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
}
)
@@ -350,19 +348,4 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
# Fallback: capitalize the base name as vendor
return base_name.split("-")[0].title()
elif provider == LlmProviderNames.LM_STUDIO:
# LM Studio model IDs can be paths like "publisher/model-name"
# or simple names. Use MODEL_PREFIX_TO_VENDOR for matching.
model_lower = model_name.lower()
# Check for slash-separated vendor prefix first
if "/" in model_lower:
vendor_key = model_lower.split("/")[0]
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
# Fallback to model prefix matching
for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items():
if model_lower.startswith(prefix):
return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title())
return None
return None

View File

@@ -1,16 +1,11 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -18,25 +13,22 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
@@ -49,99 +41,110 @@ def set_new_search_settings(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
Creates a new SearchSettings row and cancels the previous secondary indexing
if any exists.
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments.
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider.
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
validate_contextual_rag_model(
provider_name=search_settings_new.contextual_rag_llm_provider,
model_name=search_settings_new.contextual_rag_llm_name,
db_session=db_session,
# TODO(andrei): Re-enable.
# NOTE Enable integration external dependency tests in test_search_settings.py
# when this is reenabled. They are currently skipped
logger.error("Setting new search settings is temporarily disabled.")
raise OnyxError(
OnyxErrorCode.NOT_IMPLEMENTED,
"Setting new search settings is temporarily disabled.",
)
# if search_settings_new.index_name:
# logger.warning("Index name was specified by request, this is not suggested")
search_settings = get_current_search_settings(db_session)
# # Disallow contextual RAG for cloud deployments
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Contextual RAG disabled in Onyx Cloud",
# )
if search_settings_new.index_name is None:
# We define index name here.
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
# # Validate cloud provider exists or create new LiteLLM provider
# if search_settings_new.provider_type is not None:
# cloud_provider = get_embedding_provider_from_provider_type(
# db_session, provider_type=search_settings_new.provider_type
# )
secondary_search_settings = get_secondary_search_settings(db_session)
# if cloud_provider is None:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
# )
if secondary_search_settings:
# Cancel any background indexing jobs.
expire_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# validate_contextual_rag_model(
# provider_name=search_settings_new.contextual_rag_llm_provider,
# model_name=search_settings_new.contextual_rag_llm_name,
# db_session=db_session,
# )
# Mark previous model as a past model directly.
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
# search_settings = get_current_search_settings(db_session)
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# if search_settings_new.index_name is None:
# # We define index name here
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
# if (
# search_settings_new.model_name == search_settings.model_name
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
# ):
# index_name += ALT_INDEX_SUFFIX
# search_values = search_settings_new.model_dump()
# search_values["index_name"] = index_name
# new_search_settings_request = SavedSearchSettings(**search_values)
# else:
# new_search_settings_request = SavedSearchSettings(
# **search_settings_new.model_dump()
# )
# Ensure the document indices have the new index immediately.
document_indices = get_all_document_indices(search_settings, new_search_settings)
for document_index in document_indices:
document_index.ensure_indices_exist(
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# secondary_search_settings = get_secondary_search_settings(db_session)
# Pause index attempts for the currently in-use index to preserve resources.
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
# if secondary_search_settings:
# # Cancel any background indexing jobs
# expire_index_attempts(
# search_settings_id=secondary_search_settings.id, db_session=db_session
# )
db_session.commit()
return IdReturn(id=new_search_settings.id)
# # Mark previous model as a past model directly
# update_search_settings_status(
# search_settings=secondary_search_settings,
# new_status=IndexModelStatus.PAST,
# db_session=db_session,
# )
# new_search_settings = create_search_settings(
# search_settings=new_search_settings_request, db_session=db_session
# )
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
# primary_embedding_precision=search_settings.embedding_precision,
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
# )
# # Pause index attempts for the currently in use index to preserve resources
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# expire_index_attempts(
# search_settings_id=search_settings.id, db_session=db_session
# )
# for cc_pair in get_connector_credential_pairs(db_session):
# resync_cc_pair(
# cc_pair=cc_pair,
# search_settings_id=new_search_settings.id,
# db_session=db_session,
# )
# db_session.commit()
# return IdReturn(id=new_search_settings.id)
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
@@ -188,7 +191,7 @@ def delete_search_settings_endpoint(
search_settings_id=deletion_request.search_settings_id,
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
@router.get("/get-current-search-settings")
@@ -238,9 +241,9 @@ def update_saved_search_settings(
) -> None:
# Disallow contextual RAG for cloud deployments
if MULTI_TENANT and search_settings.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Contextual RAG disabled in Onyx Cloud",
)
validate_contextual_rag_model(
@@ -294,7 +297,7 @@ def validate_contextual_rag_model(
model_name=model_name,
db_session=db_session,
):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
def _validate_contextual_rag_model(

View File

@@ -1,5 +1,6 @@
import datetime
import json
import os
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -60,6 +61,7 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -810,6 +812,18 @@ def fetch_chat_file(
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import json
from typing import Any
from typing import cast
from typing import Literal
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.chat.citation_utils import extract_citation_order_from_text
@@ -22,9 +20,7 @@ from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import FileReaderResult
from onyx.server.query_and_chat.streaming_models import FileReaderStart
@@ -184,37 +180,24 @@ def create_custom_tool_packets(
tab_index: int = 0,
data: dict | list | str | int | float | bool | None = None,
file_ids: list[str] | None = None,
error: CustomToolErrorInfo | None = None,
tool_args: dict[str, Any] | None = None,
tool_id: int | None = None,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id),
obj=CustomToolStart(tool_name=tool_name),
)
)
if tool_args:
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=CustomToolDelta(
tool_name=tool_name,
tool_id=tool_id,
response_type=response_type,
data=data,
file_ids=file_ids,
error=error,
),
),
)
@@ -674,55 +657,13 @@ def translate_assistant_message_to_packets(
else:
# Custom tool or unknown tool
# Try to parse as structured CustomToolCallSummary JSON
custom_data: dict | list | str | int | float | bool | None = (
tool_call.tool_call_response
)
custom_error: CustomToolErrorInfo | None = None
custom_response_type = "text"
try:
parsed = json.loads(tool_call.tool_call_response)
if isinstance(parsed, dict) and "tool_name" in parsed:
custom_data = parsed.get("tool_result")
custom_response_type = parsed.get(
"response_type", "text"
)
if parsed.get("error"):
custom_error = CustomToolErrorInfo(
**parsed["error"]
)
except (
json.JSONDecodeError,
KeyError,
TypeError,
ValidationError,
):
pass
custom_file_ids: list[str] | None = None
if custom_response_type in ("image", "csv") and isinstance(
custom_data, dict
):
custom_file_ids = custom_data.get("file_ids")
custom_data = None
custom_args = {
k: v
for k, v in (tool_call.tool_call_arguments or {}).items()
if k != "requestBody"
}
turn_tool_packets.extend(
create_custom_tool_packets(
tool_name=tool.display_name or tool.name,
response_type=custom_response_type,
response_type="text",
turn_index=turn_num,
tab_index=tool_call.tab_index,
data=custom_data,
file_ids=custom_file_ids,
error=custom_error,
tool_args=custom_args if custom_args else None,
tool_id=tool_call.tool_id,
data=tool_call.tool_call_response,
)
)

View File

@@ -33,7 +33,6 @@ class StreamingType(Enum):
PYTHON_TOOL_START = "python_tool_start"
PYTHON_TOOL_DELTA = "python_tool_delta"
CUSTOM_TOOL_START = "custom_tool_start"
CUSTOM_TOOL_ARGS = "custom_tool_args"
CUSTOM_TOOL_DELTA = "custom_tool_delta"
FILE_READER_START = "file_reader_start"
FILE_READER_RESULT = "file_reader_result"
@@ -246,20 +245,6 @@ class CustomToolStart(BaseObj):
type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value
tool_name: str
tool_id: int | None = None
class CustomToolArgs(BaseObj):
type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value
tool_name: str
tool_args: dict[str, Any]
class CustomToolErrorInfo(BaseModel):
is_auth_error: bool = False
status_code: int
message: str
# The allowed streamed packets for a custom tool
@@ -267,13 +252,11 @@ class CustomToolDelta(BaseObj):
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
tool_name: str
tool_id: int | None = None
response_type: str
# For non-file responses
data: dict | list | str | int | float | bool | None = None
# For file-based responses like image/csv
file_ids: list[str] | None = None
error: CustomToolErrorInfo | None = None
################################################
@@ -383,7 +366,6 @@ PacketObj = Union[
PythonToolStart,
PythonToolDelta,
CustomToolStart,
CustomToolArgs,
CustomToolDelta,
FileReaderStart,
FileReaderResult,

View File

@@ -8,6 +8,8 @@ from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
@@ -163,6 +165,39 @@ def create_image_generation_packets(
return packets
def create_custom_tool_packets(
tool_name: str,
response_type: str,
turn_index: int,
data: dict | list | str | int | float | bool | None = None,
file_ids: list[str] | None = None,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=CustomToolStart(tool_name=tool_name),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=CustomToolDelta(
tool_name=tool_name,
response_type=response_type,
data=data,
file_ids=file_ids,
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_fetch_packets(
fetch_docs: list[SavedSearchDoc],
urls: list[str],

View File

@@ -60,11 +60,9 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -275,13 +275,9 @@ def setup_postgres(db_session: Session) -> None:
],
api_key_changed=True,
)
try:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
except ValueError as e:
logger.warning("Failed to upsert LLM provider during setup: %s", e)
return
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)

View File

@@ -18,7 +18,6 @@ from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@@ -62,7 +61,6 @@ class CustomToolCallSummary(BaseModel):
tool_name: str
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
tool_result: Any # The response data
error: CustomToolErrorInfo | None = None
class ToolCallKickoff(BaseModel):

View File

@@ -15,9 +15,7 @@ from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
@@ -141,7 +139,7 @@ class CustomTool(Tool[None]):
self.emitter.emit(
Packet(
placement=placement,
obj=CustomToolStart(tool_name=self._name, tool_id=self._id),
obj=CustomToolStart(tool_name=self._name),
)
)
@@ -151,8 +149,10 @@ class CustomTool(Tool[None]):
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
# Build path params
request_body = llm_kwargs.get(REQUEST_BODY)
path_params = {}
for path_param_schema in self._method_spec.get_path_param_schemas():
param_name = path_param_schema["name"]
if param_name not in llm_kwargs:
@@ -165,7 +165,6 @@ class CustomTool(Tool[None]):
)
path_params[param_name] = llm_kwargs[param_name]
# Build query params
query_params = {}
for query_param_schema in self._method_spec.get_query_param_schemas():
if query_param_schema["name"] in llm_kwargs:
@@ -173,20 +172,6 @@ class CustomTool(Tool[None]):
query_param_schema["name"]
]
# Emit args packet (path + query params only, no request body)
tool_args = {**path_params, **query_params}
if tool_args:
self.emitter.emit(
Packet(
placement=placement,
obj=CustomToolArgs(
tool_name=self._name,
tool_args=tool_args,
),
)
)
request_body = llm_kwargs.get(REQUEST_BODY)
url = self._method_spec.build_url(self._base_url, path_params, query_params)
method = self._method_spec.method
@@ -195,18 +180,6 @@ class CustomTool(Tool[None]):
)
content_type = response.headers.get("Content-Type", "")
# Detect HTTP errors — only 401/403 are flagged as auth errors
error_info: CustomToolErrorInfo | None = None
if response.status_code in (401, 403):
error_info = CustomToolErrorInfo(
is_auth_error=True,
status_code=response.status_code,
message=f"{self._name} action failed because of authentication error",
)
logger.warning(
f"Auth error from custom tool '{self._name}': HTTP {response.status_code}"
)
tool_result: Any
response_type: str
file_ids: List[str] | None = None
@@ -249,11 +222,9 @@ class CustomTool(Tool[None]):
placement=placement,
obj=CustomToolDelta(
tool_name=self._name,
tool_id=self._id,
response_type=response_type,
data=data,
file_ids=file_ids,
error=error_info,
),
)
)
@@ -265,7 +236,6 @@ class CustomTool(Tool[None]):
tool_name=self._name,
response_type=response_type,
tool_result=tool_result,
error=error_info,
),
llm_facing_response=llm_facing_response,
)

View File

@@ -111,26 +111,19 @@ def _normalize_text_with_mapping(text: str) -> tuple[str, list[int]]:
# Step 1: NFC normalization with position mapping
nfc_text = unicodedata.normalize("NFC", text)
# Map NFD positions original positions.
# NFD only decomposes, so each original char produces 1+ NFD chars.
nfd_to_orig: list[int] = []
for orig_idx, orig_char in enumerate(original_text):
nfd_of_char = unicodedata.normalize("NFD", orig_char)
for _ in nfd_of_char:
nfd_to_orig.append(orig_idx)
# Map NFC positions → NFD positions.
# Each NFC char, when decomposed, tells us exactly how many NFD
# chars it was composed from.
# Build mapping from NFC positions to original start positions
nfc_to_orig: list[int] = []
nfd_idx = 0
orig_idx = 0
for nfc_char in nfc_text:
if nfd_idx < len(nfd_to_orig):
nfc_to_orig.append(nfd_to_orig[nfd_idx])
nfc_to_orig.append(orig_idx)
# Find how many original chars contributed to this NFC char
for length in range(1, len(original_text) - orig_idx + 1):
substr = original_text[orig_idx : orig_idx + length]
if unicodedata.normalize("NFC", substr) == nfc_char:
orig_idx += length
break
else:
nfc_to_orig.append(len(original_text) - 1)
nfd_of_nfc = unicodedata.normalize("NFD", nfc_char)
nfd_idx += len(nfd_of_nfc)
orig_idx += 1 # Fallback
# Work with NFC text from here
text = nfc_text

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import json
import time
from collections.abc import Generator
@@ -86,19 +84,6 @@ class CodeInterpreterClient:
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
self._closed = False
def __enter__(self) -> CodeInterpreterClient:
return self
def __exit__(self, *args: object) -> None:
self.close()
def close(self) -> None:
if self._closed:
return
self.session.close()
self._closed = True
def _build_payload(
self,
@@ -192,11 +177,8 @@ class CodeInterpreterClient:
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
try:
response.raise_for_status()
yield from self._parse_sse(response)
finally:
response.close()
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response

View File

@@ -111,8 +111,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not server.server_enabled:
return False
with CodeInterpreterClient() as client:
return client.health(use_cache=True)
client = CodeInterpreterClient()
return client.health(use_cache=True)
def tool_definition(self) -> dict:
return {
@@ -176,203 +176,196 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
)
)
# Create Code Interpreter client — context manager ensures
# session.close() is called on every exit path.
with CodeInterpreterClient() as client:
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
logger.info(f"Staged file for Python execution: {file_name}")
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
# Create Code Interpreter client
client = CodeInterpreterClient()
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
logger.debug(f"Executing code: {code}")
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=(
event.data if event.stream == "stdout" else ""
),
stderr=(
event.data if event.stream == "stderr" else ""
),
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
logger.info(f"Staged file for Python execution: {file_name}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
try:
logger.debug(f"Executing code: {code}")
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file "
f"{workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated "
f"file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged "
f"file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=(None if result_event.exit_code == 0 else truncated_stderr),
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
# Emit error delta
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file {workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)

View File

@@ -1,17 +0,0 @@
"""
jsonriver - A streaming JSON parser for Python
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
Gives you a sequence of increasingly complete values.
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from .parse import _Parser as Parser
from .parse import JsonObject
from .parse import JsonValue
__all__ = ["Parser", "JsonValue", "JsonObject"]
__version__ = "0.0.1"

View File

@@ -1,427 +0,0 @@
"""
JSON parser for streaming incremental parsing
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from __future__ import annotations
import copy
from enum import IntEnum
from typing import cast
from typing import Union
from .tokenize import _Input
from .tokenize import json_token_type_to_string
from .tokenize import JsonTokenType
from .tokenize import Tokenizer
# Type definitions for JSON values
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
JsonObject = dict[str, JsonValue]
class _StateEnum(IntEnum):
"""Parser state machine states"""
Initial = 0
InString = 1
InArray = 2
InObjectExpectingKey = 3
InObjectExpectingValue = 4
class _State:
"""Base class for parser states"""
type: _StateEnum
value: JsonValue | tuple[str, JsonObject] | None
class _InitialState(_State):
"""Initial state before any parsing"""
def __init__(self) -> None:
self.type = _StateEnum.Initial
self.value = None
class _InStringState(_State):
"""State while parsing a string"""
def __init__(self) -> None:
self.type = _StateEnum.InString
self.value = ""
class _InArrayState(_State):
"""State while parsing an array"""
def __init__(self) -> None:
self.type = _StateEnum.InArray
self.value: list[JsonValue] = []
class _InObjectExpectingKeyState(_State):
"""State while parsing an object, expecting a key"""
def __init__(self) -> None:
self.type = _StateEnum.InObjectExpectingKey
self.value: JsonObject = {}
class _InObjectExpectingValueState(_State):
"""State while parsing an object, expecting a value"""
def __init__(self, key: str, obj: JsonObject) -> None:
self.type = _StateEnum.InObjectExpectingValue
self.value = (key, obj)
# Sentinel value to distinguish "not set" from "set to None/null"
class _Unset:
pass
_UNSET = _Unset()
class _Parser:
"""
Incremental JSON parser
Feed chunks of JSON text via feed() and get back progressively
more complete JSON values.
"""
def __init__(self) -> None:
self._state_stack: list[_State] = [_InitialState()]
self._toplevel_value: JsonValue | _Unset = _UNSET
self._input = _Input()
self.tokenizer = Tokenizer(self._input, self)
self._finished = False
self._progressed = False
self._prev_snapshot: JsonValue | _Unset = _UNSET
def feed(self, chunk: str) -> list[JsonValue]:
"""
Feed a chunk of JSON text and return deltas from the previous state.
Each element in the returned list represents what changed since the
last yielded value. For dicts, only changed/new keys are included,
with string values containing only the newly appended characters.
"""
if self._finished:
return []
self._input.feed(chunk)
return self._collect_deltas()
@staticmethod
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
if prev is None:
return current
if isinstance(current, dict) and isinstance(prev, dict):
result: JsonObject = {}
for key in current:
cur_val = current[key]
prev_val = prev.get(key)
if key not in prev:
result[key] = cur_val
elif isinstance(cur_val, str) and isinstance(prev_val, str):
if cur_val != prev_val:
result[key] = cur_val[len(prev_val) :]
elif isinstance(cur_val, list) and isinstance(prev_val, list):
if cur_val != prev_val:
new_items = cur_val[len(prev_val) :]
# check if the last existing element was updated
if (
prev_val
and len(cur_val) >= len(prev_val)
and cur_val[len(prev_val) - 1] != prev_val[-1]
):
result[key] = [cur_val[len(prev_val) - 1]] + new_items
elif new_items:
result[key] = new_items
elif cur_val != prev_val:
result[key] = cur_val
return result if result else None
if isinstance(current, str) and isinstance(prev, str):
delta = current[len(prev) :]
return delta if delta else None
if isinstance(current, list) and isinstance(prev, list):
if current != prev:
new_items = current[len(prev) :]
if (
prev
and len(current) >= len(prev)
and current[len(prev) - 1] != prev[-1]
):
return [current[len(prev) - 1]] + new_items
return new_items if new_items else None
return None
if current != prev:
return current
return None
def finish(self) -> list[JsonValue]:
"""Signal that no more chunks will be fed. Validates trailing content.
Returns any final deltas produced by flushing pending tokens (e.g.
numbers, which have no terminator and wait for more input).
"""
self._input.mark_complete()
# Pump once more so the tokenizer can emit tokens that were waiting
# for more input (e.g. numbers need buffer_complete to finalize).
results = self._collect_deltas()
self._input.expect_end_of_content()
return results
def _collect_deltas(self) -> list[JsonValue]:
"""Run one pump cycle and return any deltas produced."""
results: list[JsonValue] = []
while True:
self._progressed = False
self.tokenizer.pump()
if self._progressed:
if self._toplevel_value is _UNSET:
raise RuntimeError(
"Internal error: toplevel_value should not be unset "
"after progressing"
)
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
if isinstance(self._prev_snapshot, _Unset):
results.append(current)
else:
delta = self._compute_delta(self._prev_snapshot, current)
if delta is not None:
results.append(delta)
self._prev_snapshot = current
else:
if not self._state_stack:
self._finished = True
break
return results
# TokenHandler protocol implementation
def handle_null(self) -> None:
"""Handle null token"""
self._handle_value_token(JsonTokenType.Null, None)
def handle_boolean(self, value: bool) -> None:
"""Handle boolean token"""
self._handle_value_token(JsonTokenType.Boolean, value)
def handle_number(self, value: float) -> None:
"""Handle number token"""
self._handle_value_token(JsonTokenType.Number, value)
def handle_string_start(self) -> None:
"""Handle string start token"""
state = self._current_state()
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
self._progressed = True
if state.type == _StateEnum.Initial:
self._state_stack.pop()
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
elif state.type == _StateEnum.InArray:
v = self._progress_value(JsonTokenType.StringStart, None)
arr = cast(list[JsonValue], state.value)
arr.append(v)
elif state.type == _StateEnum.InObjectExpectingKey:
self._state_stack.append(_InStringState())
elif state.type == _StateEnum.InObjectExpectingValue:
key, obj = cast(tuple[str, JsonObject], state.value)
sv = self._progress_value(JsonTokenType.StringStart, None)
obj[key] = sv
elif state.type == _StateEnum.InString:
raise ValueError(
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
f"token in the middle of string"
)
def handle_string_middle(self, value: str) -> None:
"""Handle string middle token"""
state = self._current_state()
if not self._progressed:
if len(self._state_stack) >= 2:
prev = self._state_stack[-2]
if prev.type != _StateEnum.InObjectExpectingKey:
self._progressed = True
else:
self._progressed = True
if state.type != _StateEnum.InString:
raise ValueError(
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
f"token when not in string"
)
assert isinstance(state.value, str)
state.value += value
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
self._update_string_parent(state.value, parent_state)
def handle_string_end(self) -> None:
"""Handle string end token"""
state = self._current_state()
if state.type != _StateEnum.InString:
raise ValueError(
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
f"token when not in string"
)
self._state_stack.pop()
parent_state = self._state_stack[-1] if self._state_stack else None
assert isinstance(state.value, str)
self._update_string_parent(state.value, parent_state)
def handle_array_start(self) -> None:
"""Handle array start token"""
self._handle_value_token(JsonTokenType.ArrayStart, None)
def handle_array_end(self) -> None:
"""Handle array end token"""
state = self._current_state()
if state.type != _StateEnum.InArray:
raise ValueError(
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
)
self._state_stack.pop()
def handle_object_start(self) -> None:
"""Handle object start token"""
self._handle_value_token(JsonTokenType.ObjectStart, None)
def handle_object_end(self) -> None:
"""Handle object end token"""
state = self._current_state()
if state.type in (
_StateEnum.InObjectExpectingKey,
_StateEnum.InObjectExpectingValue,
):
self._state_stack.pop()
else:
raise ValueError(
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
)
# Private helper methods
def _current_state(self) -> _State:
"""Get current parser state"""
if not self._state_stack:
raise ValueError("Unexpected trailing input")
return self._state_stack[-1]
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
"""Handle a complete value token"""
state = self._current_state()
if not self._progressed:
self._progressed = True
if state.type == _StateEnum.Initial:
self._state_stack.pop()
self._toplevel_value = self._progress_value(token_type, value)
elif state.type == _StateEnum.InArray:
v = self._progress_value(token_type, value)
arr = cast(list[JsonValue], state.value)
arr.append(v)
elif state.type == _StateEnum.InObjectExpectingValue:
key, obj = cast(tuple[str, JsonObject], state.value)
if token_type != JsonTokenType.StringStart:
self._state_stack.pop()
new_state = _InObjectExpectingKeyState()
new_state.value = obj
self._state_stack.append(new_state)
v = self._progress_value(token_type, value)
obj[key] = v
elif state.type == _StateEnum.InString:
raise ValueError(
f"Unexpected {json_token_type_to_string(token_type)} "
f"token in the middle of string"
)
elif state.type == _StateEnum.InObjectExpectingKey:
raise ValueError(
f"Unexpected {json_token_type_to_string(token_type)} "
f"token in the middle of object expecting key"
)
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
"""Update parent container with updated string value"""
if parent_state is None:
self._toplevel_value = updated
elif parent_state.type == _StateEnum.InArray:
arr = cast(list[JsonValue], parent_state.value)
arr[-1] = updated
elif parent_state.type == _StateEnum.InObjectExpectingValue:
key, obj = cast(tuple[str, JsonObject], parent_state.value)
obj[key] = updated
if self._state_stack and self._state_stack[-1] == parent_state:
self._state_stack.pop()
new_state = _InObjectExpectingKeyState()
new_state.value = obj
self._state_stack.append(new_state)
elif parent_state.type == _StateEnum.InObjectExpectingKey:
if self._state_stack and self._state_stack[-1] == parent_state:
self._state_stack.pop()
obj = cast(JsonObject, parent_state.value)
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
"""Create initial value for a token and push appropriate state"""
if token_type == JsonTokenType.Null:
return None
elif token_type == JsonTokenType.Boolean:
return value
elif token_type == JsonTokenType.Number:
return value
elif token_type == JsonTokenType.StringStart:
string_state = _InStringState()
self._state_stack.append(string_state)
return ""
elif token_type == JsonTokenType.ArrayStart:
array_state = _InArrayState()
self._state_stack.append(array_state)
return array_state.value
elif token_type == JsonTokenType.ObjectStart:
object_state = _InObjectExpectingKeyState()
self._state_stack.append(object_state)
return object_state.value
else:
raise ValueError(
f"Unexpected token type: {json_token_type_to_string(token_type)}"
)

View File

@@ -1,514 +0,0 @@
"""
JSON tokenizer for streaming incremental parsing
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from __future__ import annotations
import re
from enum import IntEnum
from typing import Protocol
class TokenHandler(Protocol):
"""Protocol for handling JSON tokens"""
def handle_null(self) -> None: ...
def handle_boolean(self, value: bool) -> None: ...
def handle_number(self, value: float) -> None: ...
def handle_string_start(self) -> None: ...
def handle_string_middle(self, value: str) -> None: ...
def handle_string_end(self) -> None: ...
def handle_array_start(self) -> None: ...
def handle_array_end(self) -> None: ...
def handle_object_start(self) -> None: ...
def handle_object_end(self) -> None: ...
class JsonTokenType(IntEnum):
"""Types of JSON tokens"""
Null = 0
Boolean = 1
Number = 2
StringStart = 3
StringMiddle = 4
StringEnd = 5
ArrayStart = 6
ArrayEnd = 7
ObjectStart = 8
ObjectEnd = 9
def json_token_type_to_string(token_type: JsonTokenType) -> str:
"""Convert token type to readable string"""
names = {
JsonTokenType.Null: "null",
JsonTokenType.Boolean: "boolean",
JsonTokenType.Number: "number",
JsonTokenType.StringStart: "string start",
JsonTokenType.StringMiddle: "string middle",
JsonTokenType.StringEnd: "string end",
JsonTokenType.ArrayStart: "array start",
JsonTokenType.ArrayEnd: "array end",
JsonTokenType.ObjectStart: "object start",
JsonTokenType.ObjectEnd: "object end",
}
return names[token_type]
class _State(IntEnum):
"""Internal tokenizer states"""
ExpectingValue = 0
InString = 1
StartArray = 2
AfterArrayValue = 3
StartObject = 4
AfterObjectKey = 5
AfterObjectValue = 6
BeforeObjectKey = 7
# Regex for validating JSON numbers
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
def _parse_json_number(s: str) -> float:
"""Parse a JSON number string, validating format"""
if not _JSON_NUMBER_PATTERN.match(s):
raise ValueError("Invalid number")
return float(s)
class _Input:
"""
Input buffer for chunk-based JSON parsing
Manages buffering of input chunks and provides methods for
consuming and inspecting the buffer.
"""
def __init__(self) -> None:
self._buffer = ""
self._start_index = 0
self.buffer_complete = False
def feed(self, chunk: str) -> None:
"""Add a chunk of data to the buffer"""
self._buffer += chunk
def mark_complete(self) -> None:
"""Signal that no more chunks will be fed"""
self.buffer_complete = True
@property
def length(self) -> int:
"""Number of characters remaining in buffer"""
return len(self._buffer) - self._start_index
def advance(self, length: int) -> None:
"""Advance the start position by length characters"""
self._start_index += length
def peek(self, offset: int) -> str | None:
"""Peek at character at offset, or None if not available"""
idx = self._start_index + offset
if idx < len(self._buffer):
return self._buffer[idx]
return None
def peek_char_code(self, offset: int) -> int:
"""Get character code at offset"""
return ord(self._buffer[self._start_index + offset])
def slice(self, start: int, end: int) -> str:
"""Slice buffer from start to end (relative to current position)"""
return self._buffer[self._start_index + start : self._start_index + end]
def commit(self) -> None:
"""Commit consumed content, removing it from buffer"""
if self._start_index > 0:
self._buffer = self._buffer[self._start_index :]
self._start_index = 0
def remaining(self) -> str:
"""Get all remaining content in buffer"""
return self._buffer[self._start_index :]
def expect_end_of_content(self) -> None:
"""Verify no non-whitespace content remains"""
self.commit()
self.skip_past_whitespace()
if self.length != 0:
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
def skip_past_whitespace(self) -> None:
"""Skip whitespace characters"""
i = self._start_index
while i < len(self._buffer):
c = ord(self._buffer[i])
if c in (32, 9, 10, 13): # space, tab, \n, \r
i += 1
else:
break
self._start_index = i
def try_to_take_prefix(self, prefix: str) -> bool:
"""Try to consume prefix from buffer, return True if successful"""
if self._buffer.startswith(prefix, self._start_index):
self._start_index += len(prefix)
return True
return False
def try_to_take(self, length: int) -> str | None:
"""Try to take length characters, or None if not enough available"""
if self.length < length:
return None
result = self._buffer[self._start_index : self._start_index + length]
self._start_index += length
return result
def try_to_take_char_code(self) -> int | None:
"""Try to take a single character as char code, or None if buffer empty"""
if self.length == 0:
return None
code = ord(self._buffer[self._start_index])
self._start_index += 1
return code
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
"""
Consume input up to first quote or backslash
Returns tuple of (consumed_content, pattern_found)
"""
buf = self._buffer
i = self._start_index
while i < len(buf):
c = ord(buf[i])
if c <= 0x1F:
raise ValueError("Unescaped control character in string")
if c == 34 or c == 92: # " or \
result = buf[self._start_index : i]
self._start_index = i
return (result, True)
i += 1
result = buf[self._start_index :]
self._start_index = len(buf)
return (result, False)
class Tokenizer:
"""
Tokenizer for chunk-based JSON parsing
Processes chunks fed into its input buffer and calls handler methods
as JSON tokens are recognized.
"""
def __init__(self, input: _Input, handler: TokenHandler) -> None:
self.input = input
self._handler = handler
self._stack: list[_State] = [_State.ExpectingValue]
self._emitted_tokens = 0
def is_done(self) -> bool:
"""Check if tokenization is complete"""
return len(self._stack) == 0 and self.input.length == 0
def pump(self) -> None:
"""Process all available tokens in the buffer"""
while True:
before = self._emitted_tokens
self._tokenize_more()
if self._emitted_tokens == before:
self.input.commit()
return
def _tokenize_more(self) -> None:
"""Process one step of tokenization based on current state"""
if not self._stack:
return
state = self._stack[-1]
if state == _State.ExpectingValue:
self._tokenize_value()
elif state == _State.InString:
self._tokenize_string()
elif state == _State.StartArray:
self._tokenize_array_start()
elif state == _State.AfterArrayValue:
self._tokenize_after_array_value()
elif state == _State.StartObject:
self._tokenize_object_start()
elif state == _State.AfterObjectKey:
self._tokenize_after_object_key()
elif state == _State.AfterObjectValue:
self._tokenize_after_object_value()
elif state == _State.BeforeObjectKey:
self._tokenize_before_object_key()
def _tokenize_value(self) -> None:
"""Tokenize a JSON value"""
self.input.skip_past_whitespace()
if self.input.try_to_take_prefix("null"):
self._handler.handle_null()
self._emitted_tokens += 1
self._stack.pop()
return
if self.input.try_to_take_prefix("true"):
self._handler.handle_boolean(True)
self._emitted_tokens += 1
self._stack.pop()
return
if self.input.try_to_take_prefix("false"):
self._handler.handle_boolean(False)
self._emitted_tokens += 1
self._stack.pop()
return
if self.input.length > 0:
ch = self.input.peek_char_code(0)
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
# Scan for end of number
i = 0
while i < self.input.length:
c = self.input.peek_char_code(i)
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
i += 1
else:
break
if i == self.input.length and not self.input.buffer_complete:
# Need more input (numbers have no terminator)
return
number_chars = self.input.slice(0, i)
self.input.advance(i)
number = _parse_json_number(number_chars)
self._handler.handle_number(number)
self._emitted_tokens += 1
self._stack.pop()
return
if self.input.try_to_take_prefix('"'):
self._stack.pop()
self._stack.append(_State.InString)
self._handler.handle_string_start()
self._emitted_tokens += 1
self._tokenize_string()
return
if self.input.try_to_take_prefix("["):
self._stack.pop()
self._stack.append(_State.StartArray)
self._handler.handle_array_start()
self._emitted_tokens += 1
self._tokenize_array_start()
return
if self.input.try_to_take_prefix("{"):
self._stack.pop()
self._stack.append(_State.StartObject)
self._handler.handle_object_start()
self._emitted_tokens += 1
self._tokenize_object_start()
return
def _tokenize_string(self) -> None:
"""Tokenize string content"""
while True:
chunk, interrupted = self.input.take_until_quote_or_backslash()
if chunk:
self._handler.handle_string_middle(chunk)
self._emitted_tokens += 1
elif not interrupted:
return
if interrupted:
if self.input.length == 0:
return
next_char = self.input.peek(0)
if next_char == '"':
self.input.advance(1)
self._handler.handle_string_end()
self._emitted_tokens += 1
self._stack.pop()
return
# Handle escape sequences
next_char2 = self.input.peek(1)
if next_char2 is None:
return
value: str
if next_char2 == "u":
# Unicode escape: need 4 hex digits
if self.input.length < 6:
return
code = 0
for j in range(2, 6):
c = self.input.peek_char_code(j)
if 48 <= c <= 57: # 0-9
digit = c - 48
elif 65 <= c <= 70: # A-F
digit = c - 55
elif 97 <= c <= 102: # a-f
digit = c - 87
else:
raise ValueError("Bad Unicode escape in JSON")
code = (code << 4) | digit
self.input.advance(6)
self._handler.handle_string_middle(chr(code))
self._emitted_tokens += 1
continue
elif next_char2 == "n":
value = "\n"
elif next_char2 == "r":
value = "\r"
elif next_char2 == "t":
value = "\t"
elif next_char2 == "b":
value = "\b"
elif next_char2 == "f":
value = "\f"
elif next_char2 == "\\":
value = "\\"
elif next_char2 == "/":
value = "/"
elif next_char2 == '"':
value = '"'
else:
raise ValueError("Bad escape in string")
self.input.advance(2)
self._handler.handle_string_middle(value)
self._emitted_tokens += 1
def _tokenize_array_start(self) -> None:
"""Tokenize start of array (check for empty or first element)"""
self.input.skip_past_whitespace()
if self.input.length == 0:
return
if self.input.try_to_take_prefix("]"):
self._handler.handle_array_end()
self._emitted_tokens += 1
self._stack.pop()
return
self._stack.pop()
self._stack.append(_State.AfterArrayValue)
self._stack.append(_State.ExpectingValue)
self._tokenize_value()
def _tokenize_after_array_value(self) -> None:
"""Tokenize after an array value (expect , or ])"""
self.input.skip_past_whitespace()
next_char = self.input.try_to_take_char_code()
if next_char is None:
return
elif next_char == 0x5D: # ]
self._handler.handle_array_end()
self._emitted_tokens += 1
self._stack.pop()
return
elif next_char == 0x2C: # ,
self._stack.append(_State.ExpectingValue)
self._tokenize_value()
return
else:
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
def _tokenize_object_start(self) -> None:
"""Tokenize start of object (check for empty or first key)"""
self.input.skip_past_whitespace()
next_char = self.input.try_to_take_char_code()
if next_char is None:
return
elif next_char == 0x7D: # }
self._handler.handle_object_end()
self._emitted_tokens += 1
self._stack.pop()
return
elif next_char == 0x22: # "
self._stack.pop()
self._stack.append(_State.AfterObjectKey)
self._stack.append(_State.InString)
self._handler.handle_string_start()
self._emitted_tokens += 1
self._tokenize_string()
return
else:
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
def _tokenize_after_object_key(self) -> None:
"""Tokenize after object key (expect :)"""
self.input.skip_past_whitespace()
next_char = self.input.try_to_take_char_code()
if next_char is None:
return
elif next_char == 0x3A: # :
self._stack.pop()
self._stack.append(_State.AfterObjectValue)
self._stack.append(_State.ExpectingValue)
self._tokenize_value()
return
else:
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
def _tokenize_after_object_value(self) -> None:
"""Tokenize after object value (expect , or })"""
self.input.skip_past_whitespace()
next_char = self.input.try_to_take_char_code()
if next_char is None:
return
elif next_char == 0x7D: # }
self._handler.handle_object_end()
self._emitted_tokens += 1
self._stack.pop()
return
elif next_char == 0x2C: # ,
self._stack.pop()
self._stack.append(_State.BeforeObjectKey)
self._tokenize_before_object_key()
return
else:
raise ValueError(
f"Expected , or }} after object value, got {chr(next_char)!r}"
)
def _tokenize_before_object_key(self) -> None:
"""Tokenize before object key (after comma)"""
self.input.skip_past_whitespace()
next_char = self.input.try_to_take_char_code()
if next_char is None:
return
elif next_char == 0x22: # "
self._stack.pop()
self._stack.append(_State.AfterObjectKey)
self._stack.append(_State.InString)
self._handler.handle_string_start()
self._emitted_tokens += 1
self._tokenize_string()
return
else:
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")

View File

@@ -24,9 +24,6 @@ class OnyxVersion:
def set_ee(self) -> None:
self._is_ee = True
def unset_ee(self) -> None:
self._is_ee = False
def is_ee_version(self) -> bool:
return self._is_ee

Some files were not shown because too many files have changed in this diff Show More