mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-09 17:52:40 +00:00
Compare commits
6 Commits
jamison/on
...
refactor/o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ca8050540 | ||
|
|
85186fc58a | ||
|
|
7e37c48fb7 | ||
|
|
7122e5a9a7 | ||
|
|
5ad3dd370f | ||
|
|
66c464c29a |
@@ -1,186 +0,0 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
96
.github/workflows/deployment.yml
vendored
96
.github/workflows/deployment.yml
vendored
@@ -182,52 +182,8 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
# Create GitHub release first, before desktop builds start.
|
||||
# This ensures all desktop matrix jobs upload to the same release instead of
|
||||
# racing to create duplicate releases.
|
||||
create-release:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
release-id: ${{ steps.create-release.outputs.id }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Determine release tag
|
||||
id: release-tag
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
body: "See the assets to download this version and install."
|
||||
draft: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
build-desktop:
|
||||
needs:
|
||||
- determine-builds
|
||||
- create-release
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
@@ -252,12 +208,12 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -397,9 +353,11 @@ jobs:
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
# Use the release created by the create-release job to avoid race conditions
|
||||
# when multiple matrix jobs try to create/update the same release simultaneously
|
||||
releaseId: ${{ needs.create-release.outputs.release-id }}
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
@@ -426,7 +384,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -500,7 +458,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -569,7 +527,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -639,7 +597,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -721,7 +679,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -798,7 +756,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -865,7 +823,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -938,7 +896,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1006,7 +964,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1076,7 +1034,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1149,7 +1107,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1218,7 +1176,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1288,7 +1246,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1368,7 +1326,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1442,7 +1400,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1507,7 +1465,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1562,7 +1520,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1622,7 +1580,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1679,7 +1637,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -15,8 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
@@ -6,13 +6,11 @@ on:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
56
.github/workflows/pr-golang-tests.yml
vendored
56
.github/workflows/pr-golang-tests.yml
vendored
@@ -1,56 +0,0 @@
|
||||
name: Golang Tests
|
||||
concurrency:
|
||||
group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
GO_VERSION: "1.26"
|
||||
|
||||
jobs:
|
||||
detect-modules:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
modules: ${{ steps.set-modules.outputs.modules }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
with:
|
||||
persist-credentials: false
|
||||
- id: set-modules
|
||||
run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
|
||||
|
||||
golang:
|
||||
needs: detect-modules
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
cache-dependency-path: "**/go.sum"
|
||||
|
||||
- run: go mod tidy
|
||||
working-directory: ${{ matrix.modules }}
|
||||
- run: git diff --exit-code go.mod go.sum
|
||||
working-directory: ${{ matrix.modules }}
|
||||
|
||||
- run: go test ./...
|
||||
working-directory: ${{ matrix.modules }}
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
52
.github/workflows/pr-integration-tests.yml
vendored
52
.github/workflows/pr-integration-tests.yml
vendored
@@ -335,6 +335,7 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
@@ -470,13 +471,13 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
onyx-lite-tests:
|
||||
no-vectordb-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-onyx-lite-tests",
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
@@ -494,12 +495,13 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for Onyx Lite Docker Compose
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
@@ -507,23 +509,28 @@ jobs:
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for Onyx Lite (Postgres + API server)
|
||||
- name: Start Docker containers (onyx-lite)
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_onyx_lite
|
||||
id: start_docker_no_vectordb
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (onyx-lite)..."
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
@@ -545,14 +552,14 @@ jobs:
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run Onyx Lite Integration Tests
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running onyx-lite integration tests..."
|
||||
echo "Running no-vectordb integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -563,38 +570,39 @@ jobs:
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (onyx-lite)
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
|
||||
- name: Dump all-container logs (onyx-lite)
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
|
||||
- name: Upload logs (onyx-lite)
|
||||
- name: Upload logs (no-vectordb)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-onyx-lite
|
||||
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
|
||||
- name: Stop Docker containers (onyx-lite)
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
@@ -736,7 +744,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, onyx-lite-tests, multitenant-tests]
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
|
||||
110
.github/workflows/pr-playwright-tests.yml
vendored
110
.github/workflows/pr-playwright-tests.yml
vendored
@@ -268,11 +268,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
@@ -280,7 +279,6 @@ jobs:
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
@@ -461,7 +459,7 @@ jobs:
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -592,108 +590,6 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
playwright-tests-lite:
|
||||
needs: [build-web-image, build-backend-image]
|
||||
name: Playwright Tests (lite)
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-playwright-tests-lite"
|
||||
- "extras=ecr-cache"
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-playwright-npm-
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
MOCK_LLM_RESPONSE=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
|
||||
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
|
||||
EOF
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Start Docker containers (lite)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Run Playwright tests (lite)
|
||||
working-directory: ./web
|
||||
run: npx playwright test --project lite
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-test-results-lite-${{ github.run_id }}
|
||||
path: ./web/output/playwright/
|
||||
retention-days: 30
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
env:
|
||||
WORKSPACE: ${{ github.workspace }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs > docker-compose.log
|
||||
mv docker-compose.log ${WORKSPACE}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-lite-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
@@ -790,7 +686,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [playwright-tests, playwright-tests-lite]
|
||||
needs: [playwright-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,9 +38,9 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
prek-version: '0.2.21'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
|
||||
39
.github/workflows/release-cli.yml
vendored
39
.github/workflows/release-cli.yml
vendored
@@ -1,39 +0,0 @@
|
||||
name: Release CLI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "cli/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-cli
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -22,10 +22,12 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
@@ -48,10 +48,6 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN:
|
||||
description: "AWS role ARN for OIDC auth"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -77,7 +73,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -120,7 +116,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -162,7 +158,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -268,7 +264,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
6
.github/workflows/sandbox-deployment.yml
vendored
6
.github/workflows/sandbox-deployment.yml
vendored
@@ -110,7 +110,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -180,7 +180,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -244,7 +244,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -119,11 +119,10 @@ repos:
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
language_version: "1.26.0"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
|
||||
58
.vscode/launch.json
vendored
58
.vscode/launch.json
vendored
@@ -40,7 +40,19 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery",
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery background",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
@@ -241,6 +253,35 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=20",
|
||||
"--prefetch-multiplier=4",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
@@ -485,6 +526,21 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart OpenSearch Container",
|
||||
// Generic debugger type, required arg but has no bearing on bash.
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Eval CLI",
|
||||
"type": "debugpy",
|
||||
|
||||
74
AGENTS.md
74
AGENTS.md
@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
@@ -104,10 +135,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
- Never enqueue a task without an expiration. Always supply `expires=` when
|
||||
sending tasks, either from the beat schedule or directly from another task. It
|
||||
should never be acceptable to submit code which enqueues tasks without an
|
||||
expiration, as doing so can lead to unbounded task queue growth.
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
@@ -590,45 +617,6 @@ Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# ✅ Good
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
# ✅ Good — no extra message needed
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
# ✅ Good — upstream service with dynamic status code
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
|
||||
|
||||
# ❌ Bad — using HTTPException directly
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# ❌ Bad — starlette constant
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
```
|
||||
|
||||
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
|
||||
category is needed, add it there first — do not invent ad-hoc codes.
|
||||
|
||||
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
|
||||
status code is dynamic (comes from the upstream response), use `status_code_override`:
|
||||
|
||||
```python
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
|
||||
15
backend/ee/onyx/background/celery/apps/background.py
Normal file
15
backend/ee/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -11,10 +11,11 @@ from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -141,7 +142,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from cache.
|
||||
Get license metadata from Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
@@ -149,34 +150,38 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cached = cache.get(LICENSE_METADATA_KEY)
|
||||
if not cached:
|
||||
return None
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
|
||||
try:
|
||||
cached_str = (
|
||||
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
|
||||
)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
Deletes the cached LicenseMetadata. The actual license in the database
|
||||
is not affected. Delete is idempotent — if the key doesn't exist, this
|
||||
is a no-op.
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cache.delete(LICENSE_METADATA_KEY)
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
@@ -187,7 +192,7 @@ def update_license_cache(
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the cache with license metadata.
|
||||
Update the Redis cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
@@ -206,7 +211,7 @@ def update_license_cache(
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
@@ -225,7 +230,7 @@ def update_license_cache(
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
cache.set(
|
||||
redis_client.set(
|
||||
LICENSE_METADATA_KEY,
|
||||
metadata.model_dump_json(),
|
||||
ex=LICENSE_CACHE_TTL_SECONDS,
|
||||
|
||||
@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -472,9 +471,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(
|
||||
name=user_group.name,
|
||||
time_last_modified_by_user=func.now(),
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
name=user_group.name, time_last_modified_by_user=func.now()
|
||||
)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
@@ -777,7 +774,8 @@ def update_user_group(
|
||||
cc_pair_ids=user_group_update.cc_pair_ids,
|
||||
)
|
||||
|
||||
if cc_pairs_updated and not DISABLE_VECTOR_DB:
|
||||
# only needs to sync with Vespa if the cc_pairs have been updated
|
||||
if cc_pairs_updated:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
removed_users = db_session.scalars(
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -26,6 +26,7 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -41,6 +42,7 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
@@ -56,8 +58,6 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -169,23 +169,26 @@ async def create_checkout_session(
|
||||
if seats is not None:
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if seats < used_seats:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot subscribe with fewer seats than current usage. "
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot subscribe with fewer seats than current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {seats} seats.",
|
||||
)
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -203,15 +206,18 @@ async def create_customer_portal_session(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
@@ -234,9 +240,9 @@ async def get_billing_information(
|
||||
|
||||
# Check circuit breaker (self-hosted only)
|
||||
if _is_billing_circuit_open():
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE,
|
||||
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -244,15 +250,11 @@ async def get_billing_information(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except OnyxError as e:
|
||||
except BillingServiceError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (
|
||||
OnyxErrorCode.BAD_GATEWAY.status_code,
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
|
||||
):
|
||||
if e.status_code in (502, 503, 504):
|
||||
_open_billing_circuit()
|
||||
raise
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
@@ -272,25 +274,31 @@ async def update_seats(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
# Validate that new seat count is not less than current used seats
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if request.new_seat_count < used_seats:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot reduce seats below current usage. "
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
result = await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -321,18 +329,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -343,17 +351,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -33,6 +31,15 @@ logger = setup_logger()
|
||||
_REQUEST_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Exception raised for billing service errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
@@ -94,7 +101,7 @@ async def _make_billing_request(
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
OnyxError: If request fails
|
||||
BillingServiceError: If request fails
|
||||
"""
|
||||
|
||||
base_url = _get_base_url()
|
||||
@@ -121,17 +128,11 @@ async def _make_billing_request(
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=e.response.status_code,
|
||||
)
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
|
||||
)
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
|
||||
@@ -223,15 +223,6 @@ def get_active_scim_token(
|
||||
token = dal.get_active_token()
|
||||
if not token:
|
||||
raise HTTPException(status_code=404, detail="No active SCIM token")
|
||||
|
||||
# Derive the IdP domain from the first synced user as a heuristic.
|
||||
idp_domain: str | None = None
|
||||
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
|
||||
if mappings:
|
||||
user = dal.get_user(mappings[0].user_id)
|
||||
if user and "@" in user.email:
|
||||
idp_domain = user.email.rsplit("@", 1)[1]
|
||||
|
||||
return ScimTokenResponse(
|
||||
id=token.id,
|
||||
name=token.name,
|
||||
@@ -239,7 +230,6 @@ def get_active_scim_token(
|
||||
is_active=token.is_active,
|
||||
created_at=token.created_at,
|
||||
last_used_at=token.last_used_at,
|
||||
idp_domain=idp_domain,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -34,8 +35,6 @@ from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -128,9 +127,9 @@ async def claim_license(
|
||||
2. Without session_id: Re-claim using existing license for auth
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License claiming is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,16 +146,15 @@ async def claim_license(
|
||||
# Re-claim using existing license for auth
|
||||
metadata = get_license_metadata(db_session)
|
||||
if not metadata or not metadata.tenant_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found. Provide session_id after checkout.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No license found. Provide session_id after checkout.",
|
||||
)
|
||||
|
||||
license_row = get_license(db_session)
|
||||
if not license_row or not license_row.license_data:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found in database",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No license found in database"
|
||||
)
|
||||
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
|
||||
@@ -175,7 +173,7 @@ async def claim_license(
|
||||
license_data = data.get("license")
|
||||
|
||||
if not license_data:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
@@ -201,14 +199,12 @@ async def claim_license(
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
)
|
||||
|
||||
|
||||
@@ -225,9 +221,9 @@ async def upload_license(
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License upload is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -238,14 +234,14 @@ async def upload_license(
|
||||
# Remove any stray whitespace/newlines from user input
|
||||
license_data = license_data.strip()
|
||||
except UnicodeDecodeError:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
@@ -301,9 +297,9 @@ async def delete_license(
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License deletion is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -46,6 +46,7 @@ from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
@@ -55,7 +56,6 @@ from ee.onyx.configs.license_enforcement_config import (
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
is_gated = False
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to cache connectivity issues
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
|
||||
@@ -365,7 +365,6 @@ class ScimTokenResponse(BaseModel):
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
idp_domain: str | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
@@ -126,7 +125,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
settings.ee_features_enabled = False
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
# Fail closed - disable EE features if we can't verify license
|
||||
settings.ee_features_enabled = False
|
||||
|
||||
@@ -21,6 +21,7 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
@@ -42,8 +43,6 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -117,14 +116,9 @@ async def create_customer_portal_session(
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"stripe_customer_portal_url": portal_url}
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create customer portal session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
@@ -140,14 +134,9 @@ async def create_checkout_session(
|
||||
try:
|
||||
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
|
||||
return {"stripe_checkout_url": checkout_url}
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create checkout session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create checkout session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
@@ -158,20 +147,15 @@ async def create_subscription_session(
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create subscription session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -202,18 +186,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -224,15 +208,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
@@ -156,8 +153,3 @@ def delete_user_group(
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
@@ -120,6 +120,7 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -200,14 +201,13 @@ def user_needs_to_be_verified() -> bool:
|
||||
|
||||
|
||||
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
|
||||
142
backend/onyx/background/celery/apps/background.py
Normal file
142
backend/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received for consolidated background worker.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
|
||||
# Initialize Vespa httpx pool (needed for light worker tasks)
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class SlimConnectorExtractionResult(BaseModel):
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector.
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector."""
|
||||
|
||||
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
|
||||
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
|
||||
"""
|
||||
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
doc_ids: set[str]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class BatchResult(BaseModel):
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
def _extract_from_batch(
|
||||
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
|
||||
) -> BatchResult:
|
||||
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
|
||||
) -> tuple[set[str], list[HierarchyNode]]:
|
||||
"""Separate a batch into document IDs and hierarchy nodes.
|
||||
|
||||
ConnectorFailure items have their failed document/entity IDs added to the
|
||||
ID dict so that failed-to-retrieve documents are not accidentally pruned.
|
||||
ID set so that failed-to-retrieve documents are not accidentally pruned.
|
||||
"""
|
||||
ids: dict[str, str | None] = {}
|
||||
ids: set[str] = set()
|
||||
hierarchy_nodes: list[HierarchyNode] = []
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
ids.add(item.raw_node_id)
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids[failed_id] = None
|
||||
ids.add(failed_id)
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
ids.add(item.id)
|
||||
return ids, hierarchy_nodes
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_raw_id_to_parent: dict[str, str | None] = {}
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
all_hierarchy_nodes: list[HierarchyNode] = []
|
||||
|
||||
# Sequence (covariant) lets all the specific list[...] iterator types unify here
|
||||
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
batch_ids, batch_nodes = _extract_from_batch(doc_list)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
doc_ids=all_connector_doc_ids,
|
||||
hierarchy_nodes=all_hierarchy_nodes,
|
||||
)
|
||||
|
||||
|
||||
23
backend/onyx/background/celery/configs/background.py
Normal file
23
backend/onyx/background/celery/configs/background.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
|
||||
# This allows the worker to prefetch multiple tasks per thread
|
||||
worker_prefetch_multiplier = 4
|
||||
@@ -30,7 +30,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
transform_vespa_chunks_to_opensearch_chunks,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -48,7 +47,6 @@ from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -148,12 +146,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
@@ -168,7 +161,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
|
||||
super().progress(tag, amount)
|
||||
|
||||
|
||||
def _resolve_and_update_document_parents(
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_id_to_parent: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
|
||||
each document and bulk-update the DB. Mirrors the resolution logic in
|
||||
run_docfetching.py."""
|
||||
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
|
||||
|
||||
resolved: dict[str, int | None] = {}
|
||||
for doc_id, raw_parent_id in raw_id_to_parent.items():
|
||||
if raw_parent_id is None:
|
||||
continue
|
||||
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
|
||||
resolved[doc_id] = node_id if found else source_node_id
|
||||
|
||||
if not resolved:
|
||||
return
|
||||
|
||||
update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Pruning: resolved and updated parent hierarchy for "
|
||||
f"{len(resolved)} documents (source={source.value})"
|
||||
)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
raw_id_to_parent=all_connector_doc_ids,
|
||||
)
|
||||
|
||||
# Link hierarchy nodes to documents for sources where pages can be
|
||||
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
|
||||
all_doc_id_list = list(all_connector_doc_ids.keys())
|
||||
link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=all_doc_id_list,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
|
||||
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"onyx.background.celery.apps.background",
|
||||
"celery_app",
|
||||
)
|
||||
@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -69,8 +71,6 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
11
backend/onyx/cache/interface.py
vendored
11
backend/onyx/cache/interface.py
vendored
@@ -1,20 +1,9 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
|
||||
"""Exception types that represent transient cache connectivity / operational
|
||||
failures. Callers that want to fail-open (or fail-closed) on cache errors
|
||||
should catch this tuple instead of bare ``Exception``.
|
||||
|
||||
When adding a new ``CacheBackend`` implementation, add its transient error
|
||||
base class(es) here so all call-sites pick it up automatically."""
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
|
||||
@@ -36,6 +36,7 @@ from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -51,7 +52,6 @@ from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -83,6 +83,28 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -663,7 +685,12 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
@@ -939,13 +966,6 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
# Extract generated_files if this is a code interpreter response
|
||||
generated_files = None
|
||||
if isinstance(tool_response.rich_response, PythonToolRichResponse):
|
||||
generated_files = (
|
||||
tool_response.rich_response.generated_files or None
|
||||
)
|
||||
|
||||
# Persist memory if this is a memory tool response
|
||||
memory_snapshot: MemoryToolResponseSnapshot | None = None
|
||||
if isinstance(tool_response.rich_response, MemoryToolResponse):
|
||||
@@ -997,7 +1017,6 @@ def run_llm_loop(
|
||||
tool_call_response=saved_response,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
generated_files=generated_files,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
@@ -55,7 +55,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -167,6 +166,15 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
|
||||
- NULL bytes (\x00): Not allowed in PostgreSQL text types
|
||||
- UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
|
||||
"""
|
||||
return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))
|
||||
|
||||
|
||||
def _try_parse_json_string(value: Any) -> Any:
|
||||
"""Attempt to parse a JSON string value into its Python equivalent.
|
||||
|
||||
@@ -214,7 +222,9 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_args, dict):
|
||||
# Parse any string values that look like JSON arrays/objects
|
||||
return {
|
||||
k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
|
||||
k: _try_parse_json_string(
|
||||
_sanitize_llm_output(v) if isinstance(v, str) else v
|
||||
)
|
||||
for k, v in raw_args.items()
|
||||
}
|
||||
|
||||
@@ -222,7 +232,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
# Sanitize before parsing to remove NULL bytes and surrogates
|
||||
raw_args = sanitize_string(raw_args)
|
||||
raw_args = _sanitize_llm_output(raw_args)
|
||||
|
||||
try:
|
||||
parsed1: Any = json.loads(raw_args)
|
||||
@@ -535,12 +545,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return sanitize_string(unescape(attr_match.group(2).strip()))
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = sanitize_string(unescape(raw_value).strip())
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
@@ -559,7 +569,6 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
arguments = sanitize_string(arguments)
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import mimetypes
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -13,42 +12,14 @@ from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.tools import create_tool_call_no_commit
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_referenced_file_descriptors(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
message_text: str,
|
||||
) -> list[FileDescriptor]:
|
||||
"""Extract FileDescriptors for code interpreter files referenced in the message text."""
|
||||
descriptors: list[FileDescriptor] = []
|
||||
for tool_call_info in tool_calls:
|
||||
if not tool_call_info.generated_files:
|
||||
continue
|
||||
for gen_file in tool_call_info.generated_files:
|
||||
file_id = (
|
||||
gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else ""
|
||||
)
|
||||
if file_id and file_id in message_text:
|
||||
mime_type, _ = mimetypes.guess_type(gen_file.filename)
|
||||
descriptors.append(
|
||||
FileDescriptor(
|
||||
id=file_id,
|
||||
type=mime_type_to_chat_file_type(mime_type),
|
||||
name=gen_file.filename,
|
||||
)
|
||||
)
|
||||
return descriptors
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -202,13 +173,8 @@ def save_chat_turn(
|
||||
pre_answer_processing_time: Duration of processing before answer starts (in seconds)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
sanitized_message_text = (
|
||||
sanitize_string(message_text) if message_text else message_text
|
||||
)
|
||||
assistant_message.message = sanitized_message_text
|
||||
assistant_message.reasoning_tokens = (
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
)
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Use pre-answer processing time (captured when MESSAGE_START was emitted)
|
||||
@@ -218,10 +184,8 @@ def save_chat_turn(
|
||||
# Calculate token count using default tokenizer, when storing, this should not use the LLM
|
||||
# specific one so we use a system default tokenizer here.
|
||||
default_tokenizer = get_tokenizer(None, None)
|
||||
if sanitized_message_text:
|
||||
assistant_message.token_count = len(
|
||||
default_tokenizer.encode(sanitized_message_text)
|
||||
)
|
||||
if message_text:
|
||||
assistant_message.token_count = len(default_tokenizer.encode(message_text))
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
@@ -333,16 +297,5 @@ def save_chat_turn(
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if sanitized_message_text:
|
||||
referenced = _extract_referenced_file_descriptors(
|
||||
tool_calls, sanitized_message_text
|
||||
)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -495,7 +495,14 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
@@ -812,9 +819,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
@@ -895,9 +900,6 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
|
||||
@@ -84,6 +84,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
|
||||
@@ -943,9 +943,6 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
|
||||
page
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -995,7 +992,6 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=page_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -781,5 +781,4 @@ def build_slim_document(
|
||||
return SlimDocument(
|
||||
id=onyx_document_id_from_drive_file(file),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
|
||||
)
|
||||
|
||||
@@ -902,11 +902,6 @@ class JiraConnector(
|
||||
external_access=self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
),
|
||||
parent_hierarchy_raw_node_id=(
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
|
||||
@@ -385,7 +385,6 @@ class IndexingDocument(Document):
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
external_access: ExternalAccess | None = None
|
||||
parent_hierarchy_raw_node_id: str | None = None
|
||||
|
||||
|
||||
class HierarchyNode(BaseModel):
|
||||
|
||||
@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
|
||||
drive_name: str,
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=driveitem.id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
def _convert_sitepage_to_slim_document(
|
||||
site_page: dict[str, Any],
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1600,22 +1594,12 @@ class SharepointConnector(
|
||||
)
|
||||
)
|
||||
|
||||
parent_hierarchy_url: str | None = None
|
||||
if drive_web_url:
|
||||
parent_hierarchy_url = self._get_parent_hierarchy_url(
|
||||
site_url, drive_web_url, drive_name, driveitem
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_driveitem_to_slim_document(
|
||||
driveitem,
|
||||
drive_name,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
driveitem, drive_name, ctx, self.graph_client
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1635,10 +1619,7 @@ class SharepointConnector(
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
site_page, ctx, self.graph_client
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
|
||||
@@ -565,7 +565,6 @@ def _get_all_doc_ids(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -676,43 +675,58 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
if value and not sanitized:
|
||||
logger.warning("Sanitization removed all characters from string")
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
|
||||
"""Remove NUL (0x00) characters from all strings in a list."""
|
||||
return [_sanitize_for_postgres(v) for v in values]
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> DBSearchDoc:
|
||||
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
|
||||
db_search_doc = DBSearchDoc(
|
||||
document_id=sanitize_string(server_search_doc.document_id),
|
||||
document_id=_sanitize_for_postgres(server_search_doc.document_id),
|
||||
chunk_ind=server_search_doc.chunk_ind,
|
||||
semantic_id=sanitize_string(server_search_doc.semantic_identifier),
|
||||
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
|
||||
link=(
|
||||
sanitize_string(server_search_doc.link)
|
||||
_sanitize_for_postgres(server_search_doc.link)
|
||||
if server_search_doc.link is not None
|
||||
else None
|
||||
),
|
||||
blurb=sanitize_string(server_search_doc.blurb),
|
||||
blurb=_sanitize_for_postgres(server_search_doc.blurb),
|
||||
source_type=server_search_doc.source_type,
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=(
|
||||
sanitize_string(server_search_doc.relevance_explanation)
|
||||
_sanitize_for_postgres(server_search_doc.relevance_explanation)
|
||||
if server_search_doc.relevance_explanation is not None
|
||||
else None
|
||||
),
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=[
|
||||
sanitize_string(h) for h in server_search_doc.match_highlights
|
||||
],
|
||||
match_highlights=_sanitize_list_for_postgres(
|
||||
server_search_doc.match_highlights
|
||||
),
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=(
|
||||
[sanitize_string(o) for o in server_search_doc.primary_owners]
|
||||
_sanitize_list_for_postgres(server_search_doc.primary_owners)
|
||||
if server_search_doc.primary_owners is not None
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
[sanitize_string(o) for o in server_search_doc.secondary_owners]
|
||||
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
|
||||
if server_search_doc.secondary_owners is not None
|
||||
else None
|
||||
),
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -247,7 +246,6 @@ def insert_document_set(
|
||||
description=document_set_creation_request.description,
|
||||
user_id=user_id,
|
||||
is_public=document_set_creation_request.is_public,
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
db_session.add(new_document_set_row)
|
||||
@@ -338,8 +336,7 @@ def update_document_set(
|
||||
)
|
||||
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
document_set_row.is_up_to_date = False
|
||||
document_set_row.is_public = document_set_update_request.is_public
|
||||
document_set_row.time_last_modified_by_user = func.now()
|
||||
versioned_private_doc_set_fn = fetch_versioned_implementation(
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""CRUD operations for HierarchyNode."""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
|
||||
return {doc_id: parent_id for doc_id, parent_id in results}
|
||||
|
||||
|
||||
def update_document_parent_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
doc_parent_map: dict[str, int | None],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
|
||||
|
||||
Only updates rows whose current value differs from the desired value to
|
||||
avoid unnecessary writes.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
|
||||
commit: Whether to commit the transaction
|
||||
|
||||
Returns:
|
||||
Number of documents actually updated
|
||||
"""
|
||||
if not doc_parent_map:
|
||||
return 0
|
||||
|
||||
doc_ids = list(doc_parent_map.keys())
|
||||
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
|
||||
|
||||
by_parent: dict[int | None, list[str]] = defaultdict(list)
|
||||
for doc_id, desired_parent_id in doc_parent_map.items():
|
||||
current = existing.get(doc_id)
|
||||
if current == desired_parent_id or doc_id not in existing:
|
||||
continue
|
||||
by_parent[desired_parent_id].append(doc_id)
|
||||
|
||||
updated = 0
|
||||
for desired_parent_id, ids in by_parent.items():
|
||||
db_session.query(Document).filter(Document.id.in_(ids)).update(
|
||||
{Document.parent_hierarchy_node_id: desired_parent_id},
|
||||
synchronize_session=False,
|
||||
)
|
||||
updated += len(ids)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif updated:
|
||||
db_session.flush()
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def update_hierarchy_node_permissions(
|
||||
db_session: Session,
|
||||
raw_node_id: str,
|
||||
|
||||
@@ -25,11 +25,8 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
@@ -535,7 +532,6 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
@@ -815,43 +811,6 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default_name = db_session.scalar(
|
||||
select(ModelConfiguration.name)
|
||||
.join(
|
||||
LLMModelFlow,
|
||||
LLMModelFlow.model_configuration_id == ModelConfiguration.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.llm_provider_id == provider.id,
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
current_default_name is not None
|
||||
and current_default_name != recommended_default.name
|
||||
):
|
||||
try:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Recommended default model '%s' not found "
|
||||
"for provider_id=%s; skipping default update.",
|
||||
recommended_default.name,
|
||||
provider.id,
|
||||
)
|
||||
|
||||
return changes
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
|
||||
@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
|
||||
latest_settings = result.scalars().first()
|
||||
|
||||
if not latest_settings:
|
||||
raise RuntimeError("No search settings specified; DB is not in a valid state.")
|
||||
raise RuntimeError("No search settings specified, DB is not in a valid state")
|
||||
return latest_settings
|
||||
|
||||
|
||||
|
||||
@@ -13,15 +13,12 @@ from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import MCPServerStatus
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.server.features.tool.models import Header
|
||||
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_json_like
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -162,26 +159,10 @@ def update_tool(
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
old_oauth_config_id = tool.oauth_config_id
|
||||
if not isinstance(oauth_config_id, UnsetType):
|
||||
tool.oauth_config_id = oauth_config_id
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if the oauth_config_id was changed
|
||||
if (
|
||||
old_oauth_config_id is not None
|
||||
and not isinstance(oauth_config_id, UnsetType)
|
||||
and old_oauth_config_id != oauth_config_id
|
||||
):
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
@@ -190,21 +171,8 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
oauth_config_id = tool.oauth_config_id
|
||||
|
||||
db_session.delete(tool)
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if no other tools reference it
|
||||
if oauth_config_id is not None:
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
db_session.flush()
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
|
||||
|
||||
def get_builtin_tool(
|
||||
@@ -288,13 +256,11 @@ def create_tool_call_no_commit(
|
||||
tab_index=tab_index,
|
||||
tool_id=tool_id,
|
||||
tool_call_id=tool_call_id,
|
||||
reasoning_tokens=(
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
),
|
||||
tool_call_arguments=sanitize_json_like(tool_call_arguments),
|
||||
tool_call_response=sanitize_json_like(tool_call_response),
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
tool_call_arguments=tool_call_arguments,
|
||||
tool_call_response=tool_call_response,
|
||||
tool_call_tokens=tool_call_tokens,
|
||||
generated_images=sanitize_json_like(generated_images),
|
||||
generated_images=generated_images,
|
||||
)
|
||||
|
||||
db_session.add(tool_call)
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
|
||||
multipass = should_use_multipass(search_settings)
|
||||
enable_large_chunks = SearchSettings.can_use_large_chunks(
|
||||
multipass, search_settings.model_name, search_settings.provider_type
|
||||
|
||||
@@ -26,10 +26,11 @@ def get_default_document_index(
|
||||
To be used for retrieval only. Indexing should be done through both indices
|
||||
until Vespa is deprecated.
|
||||
|
||||
Pre-existing docstring for this function, although secondary indices are not
|
||||
currently supported:
|
||||
Primary index is the index that is used for querying/updating etc. Secondary
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated. Updates are applied to both indices.
|
||||
WARNING: In that case, get_all_document_indices should be used.
|
||||
need to be updated, updates are applied to both indices.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
@@ -50,26 +51,11 @@ def get_default_document_index(
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -100,7 +86,8 @@ def get_all_document_indices(
|
||||
Used for indexing only. Until Vespa is deprecated we will index into both
|
||||
document indices. Retrieval is done through only one index however.
|
||||
|
||||
Large chunks are not currently supported so we hardcode appropriate values.
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
@@ -136,36 +123,13 @@ def get_all_document_indices(
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
@@ -61,25 +61,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class IndexInfo(BaseModel):
|
||||
"""
|
||||
Represents information about an OpenSearch index.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
name: str
|
||||
health: str
|
||||
status: str
|
||||
num_primary_shards: str
|
||||
num_replica_shards: str
|
||||
docs_count: str
|
||||
docs_deleted: str
|
||||
created_at: str
|
||||
total_size: str
|
||||
primary_shards_size: str
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
@@ -178,8 +159,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not response.get("acknowledged", False):
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -192,8 +173,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
response = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not response.get("acknowledged", False):
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -217,34 +198,6 @@ class OpenSearchClient(AbstractContextManager):
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def list_indices_with_info(self) -> list[IndexInfo]:
|
||||
"""
|
||||
Lists the indices in the OpenSearch cluster with information about each
|
||||
index.
|
||||
|
||||
Returns:
|
||||
A list of IndexInfo objects for each index.
|
||||
"""
|
||||
response = self._client.cat.indices(format="json")
|
||||
indices: list[IndexInfo] = []
|
||||
for raw_index_info in response:
|
||||
indices.append(
|
||||
IndexInfo(
|
||||
name=raw_index_info.get("index", ""),
|
||||
health=raw_index_info.get("health", ""),
|
||||
status=raw_index_info.get("status", ""),
|
||||
num_primary_shards=raw_index_info.get("pri", ""),
|
||||
num_replica_shards=raw_index_info.get("rep", ""),
|
||||
docs_count=raw_index_info.get("docs.count", ""),
|
||||
docs_deleted=raw_index_info.get("docs.deleted", ""),
|
||||
created_at=raw_index_info.get("creation.date.string", ""),
|
||||
total_size=raw_index_info.get("store.size", ""),
|
||||
primary_shards_size=raw_index_info.get("pri.store.size", ""),
|
||||
)
|
||||
)
|
||||
return indices
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
secondary_embedding_dim: int | None,
|
||||
secondary_embedding_precision: EmbeddingPrecision | None,
|
||||
# NOTE: We do not support large chunks right now.
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
multitenant: bool = False,
|
||||
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
f"Expected {MULTI_TENANT}, got {multitenant}."
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
self._secondary_real_index: OpenSearchDocumentIndex | None = None
|
||||
if self.secondary_index_name:
|
||||
if secondary_embedding_dim is None or secondary_embedding_precision is None:
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
self._secondary_real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=self.secondary_index_name,
|
||||
embedding_dim=secondary_embedding_dim,
|
||||
embedding_precision=secondary_embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
self,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
secondary_index_embedding_dim: int | None, # noqa: ARG002
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
|
||||
) -> None:
|
||||
self._real_index.verify_and_create_index_if_necessary(
|
||||
# Only handle primary index for now, ignore secondary.
|
||||
return self._real_index.verify_and_create_index_if_necessary(
|
||||
primary_embedding_dim, primary_embedding_precision
|
||||
)
|
||||
if self.secondary_index_name:
|
||||
if (
|
||||
secondary_index_embedding_dim is None
|
||||
or secondary_index_embedding_precision is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.verify_and_create_index_if_necessary(
|
||||
secondary_index_embedding_dim, secondary_index_embedding_precision
|
||||
)
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
# Convert IndexBatchParams to IndexingMetadata.
|
||||
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
|
||||
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
|
||||
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
total_chunks_deleted += self._secondary_real_index.delete(
|
||||
doc_id, chunk_count
|
||||
)
|
||||
return total_chunks_deleted
|
||||
return self._real_index.delete(doc_id, chunk_count)
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> None:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
f"Tried to update document {doc_id} with no updated fields or user fields."
|
||||
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
try:
|
||||
self._real_index.update([update_request])
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.update([update_request])
|
||||
except NotFoundError:
|
||||
logger.exception(
|
||||
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
|
||||
@@ -739,8 +681,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
The number of chunks successfully deleted.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id=document_id,
|
||||
@@ -776,8 +717,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
specified documents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
|
||||
)
|
||||
for update_request in update_requests:
|
||||
properties_to_update: dict[str, Any] = dict()
|
||||
@@ -833,11 +773,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here.
|
||||
# TODO(andrei): Fix the aforementioned race condition.
|
||||
raise ChunkCountNotFoundError(
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. "
|
||||
"Older versions of the application used to permit this but is not a "
|
||||
"supported state for a document when using OpenSearch. The document was "
|
||||
"likely just added to the indexing pipeline and the chunk count will be "
|
||||
"updated shortly."
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
|
||||
"application used to permit this but is not a supported state for a document when using OpenSearch. "
|
||||
"The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
|
||||
)
|
||||
if doc_chunk_count == 0:
|
||||
raise ValueError(
|
||||
@@ -869,8 +807,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
@@ -917,8 +854,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
# TODO(andrei): This could be better, the caller should just make this
|
||||
# decision when passing in the query param. See the above comment in the
|
||||
@@ -938,10 +874,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid
|
||||
# search from a theoretical standpoint. Empirically on a small dataset
|
||||
# of up to 10K docs, it's not very different. Likely more impactful at
|
||||
# scale.
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
@@ -968,8 +902,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
dirty: bool | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_random_search_query(
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -999,8 +932,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
complete.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
|
||||
f"{self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
|
||||
)
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
|
||||
@@ -243,8 +243,7 @@ class DocumentChunk(BaseModel):
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@@ -285,22 +284,19 @@ class DocumentChunk(BaseModel):
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
|
||||
f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
|
||||
"current global tenancy state."
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but "
|
||||
"multi-tenant mode is not enabled. This is unexpected because in single-tenant "
|
||||
"mode we don't expect to see a tenant_id."
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
@@ -356,10 +352,8 @@ class DocumentSchema:
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
# Language analyzer (e.g. english) stems at index and search
|
||||
# time for variant matching. Configure via
|
||||
# OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
|
||||
# after a change.
|
||||
# Language analyzer (e.g. english) stems at index and search time for variant matching.
|
||||
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
|
||||
@@ -698,6 +698,41 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -723,53 +758,41 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
)
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
@@ -19,7 +18,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
)
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
@@ -340,18 +338,12 @@ def get_all_chunks_paginated(
|
||||
params["continuation"] = continuation_token
|
||||
|
||||
response: httpx.Response | None = None
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = (
|
||||
f"Failed to get chunks from Vespa slice {slice_id} with continuation token "
|
||||
f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds."
|
||||
)
|
||||
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
|
||||
logger.exception(
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
|
||||
@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
|
||||
index_batch_params.doc_id_to_new_chunk_cnt
|
||||
):
|
||||
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
vespa_document_index.update([update_request])
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
def delete_single(
|
||||
self,
|
||||
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(),
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
|
||||
raise ValueError(
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
total_chunks_deleted = 0
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
total_chunks_deleted += vespa_document_index.delete(
|
||||
document_id=doc_id, chunk_count=chunk_count
|
||||
)
|
||||
|
||||
return total_chunks_deleted
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
|
||||
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
|
||||
@@ -52,9 +52,7 @@ def replace_invalid_doc_id_characters(text: str) -> str:
|
||||
return text.replace("'", "_")
|
||||
|
||||
|
||||
def get_vespa_http_client(
|
||||
no_timeout: bool = False, http2: bool = True, timeout: int | None = None
|
||||
) -> httpx.Client:
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
"""
|
||||
Configures and returns an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
@@ -66,7 +64,7 @@ def get_vespa_http_client(
|
||||
else None
|
||||
),
|
||||
verify=False if not MANAGED_VESPA else True,
|
||||
timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,8 +23,11 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -34,22 +37,30 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -62,12 +73,16 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -89,7 +104,8 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -98,14 +114,16 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -117,8 +135,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -129,7 +147,8 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -141,94 +160,73 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# filter_str += _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# ))
|
||||
# )
|
||||
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# ))
|
||||
# )
|
||||
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
"""
|
||||
Standardized error codes for the Onyx backend.
|
||||
|
||||
Usage:
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OnyxErrorCode(Enum):
|
||||
"""
|
||||
Each member is a tuple of (error_code_string, http_status_code).
|
||||
|
||||
The error_code_string is a stable, machine-readable identifier that
|
||||
API consumers can match on. The http_status_code is the default HTTP
|
||||
status to return.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication (401)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
|
||||
INVALID_TOKEN = ("INVALID_TOKEN", 401)
|
||||
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
|
||||
CSRF_FAILURE = ("CSRF_FAILURE", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authorization (403)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHORIZED = ("UNAUTHORIZED", 403)
|
||||
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
|
||||
ADMIN_ONLY = ("ADMIN_ONLY", 403)
|
||||
EE_REQUIRED = ("EE_REQUIRED", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation / Bad Request (400)
|
||||
# ------------------------------------------------------------------
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
# ------------------------------------------------------------------
|
||||
NOT_FOUND = ("NOT_FOUND", 404)
|
||||
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
|
||||
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
|
||||
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
|
||||
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
|
||||
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
|
||||
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conflict (409)
|
||||
# ------------------------------------------------------------------
|
||||
CONFLICT = ("CONFLICT", 409)
|
||||
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rate Limiting / Quotas (429 / 402)
|
||||
# ------------------------------------------------------------------
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
|
||||
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
|
||||
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server Errors (5xx)
|
||||
# ------------------------------------------------------------------
|
||||
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
|
||||
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
|
||||
def detail(self, message: str | None = None) -> dict[str, str]:
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
"""OnyxError — the single exception type for all Onyx business errors.
|
||||
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
For upstream errors with a dynamic HTTP status (e.g. billing service),
|
||||
use ``status_code_override``::
|
||||
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=upstream_status,
|
||||
)
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxError(Exception):
|
||||
"""Structured error that maps to a specific ``OnyxErrorCode``.
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_message = message or error_code.code
|
||||
super().__init__(resolved_message)
|
||||
self.error_code = error_code
|
||||
self.message = resolved_message
|
||||
self._status_code_override = status_code_override
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self._status_code_override or self.error_code.status_code
|
||||
|
||||
|
||||
def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register a global handler that converts ``OnyxError`` to JSON responses.
|
||||
|
||||
Must be called *after* the app is created but *before* it starts serving.
|
||||
The handler logs at WARNING for 4xx and ERROR for 5xx.
|
||||
"""
|
||||
|
||||
@app.exception_handler(OnyxError)
|
||||
async def _handle_onyx_error(
|
||||
request: Request, # noqa: ARG001
|
||||
exc: OnyxError,
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
)
|
||||
@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -64,7 +65,6 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -1,49 +1,30 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
|
||||
|
||||
|
||||
def sanitize_string(value: str) -> str:
|
||||
"""Strip characters that PostgreSQL text/JSONB columns cannot store.
|
||||
|
||||
Removes:
|
||||
- NUL bytes (\\x00)
|
||||
- UTF-16 surrogates (\\ud800-\\udfff), which are invalid in UTF-8
|
||||
"""
|
||||
sanitized = value.replace("\x00", "")
|
||||
sanitized = _SURROGATE_RE.sub("", sanitized)
|
||||
if value and not sanitized:
|
||||
logger.warning(
|
||||
"sanitize_string: all characters were removed from a non-empty string"
|
||||
)
|
||||
return sanitized
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
|
||||
|
||||
def sanitize_json_like(value: Any) -> Any:
|
||||
"""Recursively sanitize all strings in a JSON-like structure (dict/list/tuple)."""
|
||||
def _sanitize_json_like(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return sanitize_string(value)
|
||||
return _sanitize_string(value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [sanitize_json_like(item) for item in value]
|
||||
return [_sanitize_json_like(item) for item in value]
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(sanitize_json_like(item) for item in value)
|
||||
return tuple(_sanitize_json_like(item) for item in value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[Any, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
cleaned_key = sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = sanitize_json_like(nested_value)
|
||||
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
|
||||
return sanitized
|
||||
|
||||
return value
|
||||
@@ -53,27 +34,27 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
return expert.model_copy(
|
||||
update={
|
||||
"display_name": (
|
||||
sanitize_string(expert.display_name)
|
||||
_sanitize_string(expert.display_name)
|
||||
if expert.display_name is not None
|
||||
else None
|
||||
),
|
||||
"first_name": (
|
||||
sanitize_string(expert.first_name)
|
||||
_sanitize_string(expert.first_name)
|
||||
if expert.first_name is not None
|
||||
else None
|
||||
),
|
||||
"middle_initial": (
|
||||
sanitize_string(expert.middle_initial)
|
||||
_sanitize_string(expert.middle_initial)
|
||||
if expert.middle_initial is not None
|
||||
else None
|
||||
),
|
||||
"last_name": (
|
||||
sanitize_string(expert.last_name)
|
||||
_sanitize_string(expert.last_name)
|
||||
if expert.last_name is not None
|
||||
else None
|
||||
),
|
||||
"email": (
|
||||
sanitize_string(expert.email) if expert.email is not None else None
|
||||
_sanitize_string(expert.email) if expert.email is not None else None
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -82,10 +63,10 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
|
||||
return ExternalAccess(
|
||||
external_user_emails={
|
||||
sanitize_string(email) for email in external_access.external_user_emails
|
||||
_sanitize_string(email) for email in external_access.external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
sanitize_string(group_id)
|
||||
_sanitize_string(group_id)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
},
|
||||
is_public=external_access.is_public,
|
||||
@@ -95,26 +76,26 @@ def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
cleaned_doc = document.model_copy(deep=True)
|
||||
|
||||
cleaned_doc.id = sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = sanitize_string(cleaned_doc.semantic_identifier)
|
||||
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
|
||||
if cleaned_doc.title is not None:
|
||||
cleaned_doc.title = sanitize_string(cleaned_doc.title)
|
||||
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
|
||||
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id
|
||||
)
|
||||
|
||||
cleaned_doc.metadata = {
|
||||
sanitize_string(key): (
|
||||
[sanitize_string(item) for item in value]
|
||||
_sanitize_string(key): (
|
||||
[_sanitize_string(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else sanitize_string(value)
|
||||
else _sanitize_string(value)
|
||||
)
|
||||
for key, value in cleaned_doc.metadata.items()
|
||||
}
|
||||
|
||||
if cleaned_doc.doc_metadata is not None:
|
||||
cleaned_doc.doc_metadata = sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
|
||||
if cleaned_doc.primary_owners is not None:
|
||||
cleaned_doc.primary_owners = [
|
||||
@@ -132,11 +113,11 @@ def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = sanitize_string(section.link)
|
||||
section.link = _sanitize_string(section.link)
|
||||
if section.text is not None:
|
||||
section.text = sanitize_string(section.text)
|
||||
section.text = _sanitize_string(section.text)
|
||||
if section.image_file_id is not None:
|
||||
section.image_file_id = sanitize_string(section.image_file_id)
|
||||
section.image_file_id = _sanitize_string(section.image_file_id)
|
||||
|
||||
return cleaned_doc
|
||||
|
||||
@@ -148,12 +129,12 @@ def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
|
||||
cleaned_node = node.model_copy(deep=True)
|
||||
|
||||
cleaned_node.raw_node_id = sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = sanitize_string(cleaned_node.display_name)
|
||||
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
|
||||
if cleaned_node.raw_parent_id is not None:
|
||||
cleaned_node.raw_parent_id = sanitize_string(cleaned_node.raw_parent_id)
|
||||
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
|
||||
if cleaned_node.link is not None:
|
||||
cleaned_node.link = sanitize_string(cleaned_node.link)
|
||||
cleaned_node.link = _sanitize_string(cleaned_node.link)
|
||||
|
||||
if cleaned_node.external_access is not None:
|
||||
cleaned_node.external_access = _sanitize_external_access(
|
||||
@@ -22,7 +22,6 @@ class LlmProviderNames(str, Enum):
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE = "azure"
|
||||
OLLAMA_CHAT = "ollama_chat"
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
|
||||
@@ -42,7 +41,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
]
|
||||
|
||||
|
||||
@@ -58,7 +56,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.AZURE: "Azure",
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -106,7 +103,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
}
|
||||
|
||||
@@ -20,9 +20,7 @@ from onyx.llm.multi_llm import LitellmLLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
@@ -34,18 +32,14 @@ logger = setup_logger()
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
|
||||
raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
|
||||
api_key = raw.strip() if raw else None
|
||||
if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if not api_key:
|
||||
return {}
|
||||
return {
|
||||
"Authorization": (
|
||||
api_key
|
||||
if api_key.lower().startswith("bearer ")
|
||||
else f"Bearer {api_key}"
|
||||
)
|
||||
}
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
return {"Authorization": api_key}
|
||||
|
||||
# Passing these will put Onyx on the OpenRouter leaderboard
|
||||
elif provider == LlmProviderNames.OPENROUTER:
|
||||
|
||||
@@ -1512,10 +1512,6 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1530,10 +1526,6 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -2524,10 +2516,6 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.4": {
|
||||
"display_name": "GPT-5.4",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
|
||||
@@ -42,7 +42,6 @@ from onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
@@ -93,98 +92,6 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
return [prompt.model_dump(exclude_none=True)]
|
||||
|
||||
|
||||
def _normalize_content(raw: Any) -> str:
|
||||
"""Normalize a message content field to a plain string.
|
||||
|
||||
Content can be a string, None, or a list of content-block dicts
|
||||
(e.g. [{"type": "text", "text": "..."}]).
|
||||
"""
|
||||
if raw is None:
|
||||
return ""
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, list):
|
||||
return "\n".join(
|
||||
block.get("text", "") if isinstance(block, dict) else str(block)
|
||||
for block in raw
|
||||
)
|
||||
return str(raw)
|
||||
|
||||
|
||||
def _strip_tool_content_from_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tool-related messages to plain text.
|
||||
|
||||
Bedrock's Converse API requires toolConfig when messages contain
|
||||
toolUse/toolResult content blocks. When no tools are provided for the
|
||||
current request, we must convert any tool-related history into plain text
|
||||
to avoid the "toolConfig field must be defined" error.
|
||||
|
||||
This is the same approach used by _OllamaHistoryMessageFormatter.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert structured tool calls to text representation
|
||||
tool_call_lines = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
args = func.get("arguments", "{}")
|
||||
tc_id = tc.get("id", "")
|
||||
tool_call_lines.append(
|
||||
f"[Tool Call] name={name} id={tc_id} args={args}"
|
||||
)
|
||||
|
||||
existing_content = _normalize_content(msg.get("content"))
|
||||
parts = (
|
||||
[existing_content] + tool_call_lines
|
||||
if existing_content
|
||||
else tool_call_lines
|
||||
)
|
||||
new_msg = {
|
||||
"role": "assistant",
|
||||
"content": "\n".join(parts),
|
||||
}
|
||||
result.append(new_msg)
|
||||
|
||||
elif role == "tool":
|
||||
# Convert tool response to user message with text content
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = _normalize_content(msg.get("content"))
|
||||
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
|
||||
# Merge into previous user message if it is also a converted
|
||||
# tool result to avoid consecutive user messages (Bedrock requires
|
||||
# strict user/assistant alternation).
|
||||
if (
|
||||
result
|
||||
and result[-1]["role"] == "user"
|
||||
and "[Tool Result]" in result[-1].get("content", "")
|
||||
):
|
||||
result[-1]["content"] += "\n\n" + tool_result_text
|
||||
else:
|
||||
result.append({"role": "user", "content": tool_result_text})
|
||||
|
||||
else:
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -250,9 +157,6 @@ class LitellmLLM(LLM):
|
||||
elif model_provider == LlmProviderNames.OLLAMA_CHAT:
|
||||
if k == OLLAMA_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.LM_STUDIO:
|
||||
if k == LM_STUDIO_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.BEDROCK:
|
||||
if k == AWS_REGION_NAME_KWARG:
|
||||
model_kwargs[k] = v
|
||||
@@ -269,19 +173,6 @@ class LitellmLLM(LLM):
|
||||
elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT:
|
||||
model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v
|
||||
|
||||
# LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided,
|
||||
# which LM Studio rejects. Ensure we always pass an explicit key (or empty
|
||||
# string) to prevent LiteLLM from injecting its fake default.
|
||||
if model_provider == LlmProviderNames.LM_STUDIO:
|
||||
model_kwargs.setdefault("api_key", "")
|
||||
|
||||
# Users provide the server root (e.g. http://localhost:1234) but LiteLLM
|
||||
# needs /v1 for OpenAI-compatible calls.
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# Default vertex_location to "global" if not provided for Vertex AI
|
||||
# Latest gemini models are only available through the global region
|
||||
if (
|
||||
@@ -513,30 +404,13 @@ class LitellmLLM(LLM):
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
messages = _prompt_to_dicts(prompt)
|
||||
|
||||
# Bedrock's Converse API requires toolConfig when messages
|
||||
# contain toolUse/toolResult content blocks. When no tools are
|
||||
# provided for this request but the history contains tool
|
||||
# content from previous turns, strip it to plain text.
|
||||
is_bedrock = self._model_provider in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}
|
||||
if (
|
||||
is_bedrock
|
||||
and not tools
|
||||
and _messages_contain_tool_content(messages)
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=messages,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
|
||||
@@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None:
|
||||
error_msg = None
|
||||
for _ in range(2):
|
||||
try:
|
||||
llm.invoke(UserMessage(content="Do not respond"), max_tokens=50)
|
||||
llm.invoke(UserMessage(content="Do not respond"))
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
@@ -1,27 +1,52 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY,
|
||||
}
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -29,6 +54,13 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
@@ -45,7 +44,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
}
|
||||
|
||||
@@ -325,7 +323,6 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
_ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)",
|
||||
OLLAMA_PROVIDER_NAME: "Ollama",
|
||||
LM_STUDIO_PROVIDER_NAME: "LM Studio",
|
||||
ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)",
|
||||
AZURE_PROVIDER_NAME: "Azure OpenAI",
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"updated_at": "2026-02-05T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
"default_model": { "name": "gpt-5.2" },
|
||||
"additional_visible_models": [
|
||||
{ "name": "gpt-5.4" },
|
||||
{ "name": "gpt-5.2" }
|
||||
{ "name": "gpt-5-mini" },
|
||||
{ "name": "gpt-4.1" }
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
@@ -16,10 +16,6 @@
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -59,7 +59,6 @@ from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
|
||||
from onyx.db.engine.connection_warmup import warm_up_connections
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
@@ -445,8 +444,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error
|
||||
)
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
@@ -146,11 +146,6 @@ class SlackRenderer(HTMLRenderer):
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
@@ -223,48 +218,5 @@ class SlackRenderer(HTMLRenderer):
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -92,7 +92,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -573,43 +572,6 @@ def _normalize_file_names_for_backwards_compatibility(
|
||||
return file_names + file_locations[len(file_names) :]
|
||||
|
||||
|
||||
def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
require_editable: bool,
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=require_editable,
|
||||
)
|
||||
if has_requested_access:
|
||||
return cc_pair
|
||||
|
||||
# Special case: global curators should be able to manage files
|
||||
# for public file connectors even when they are not the creator.
|
||||
if (
|
||||
require_editable
|
||||
and user.role == UserRole.GLOBAL_CURATOR
|
||||
and cc_pair.access_type == AccessType.PUBLIC
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
@@ -621,7 +583,7 @@ def upload_files_api(
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
def list_connector_files(
|
||||
connector_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorFilesResponse:
|
||||
"""List all files in a file connector."""
|
||||
@@ -634,13 +596,6 @@ def list_connector_files(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=False,
|
||||
)
|
||||
|
||||
file_locations = connector.connector_specific_config.get("file_locations", [])
|
||||
file_names = connector.connector_specific_config.get("file_names", [])
|
||||
|
||||
@@ -690,7 +645,7 @@ def update_connector_files(
|
||||
connector_id: int,
|
||||
files: list[UploadFile] | None = File(None),
|
||||
file_ids_to_remove: str = Form("[]"),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
@@ -708,13 +663,12 @@ def update_connector_files(
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
# and validate user permissions for file management.
|
||||
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=True,
|
||||
)
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
# Parse file IDs to remove
|
||||
try:
|
||||
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"version": "1.19.9",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz",
|
||||
"integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1573,6 +1573,27 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/balanced-match": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
|
||||
"integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/brace-expansion": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
|
||||
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@isaacs/balanced-match": "^4.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/gen-mapping": {
|
||||
"version": "0.3.13",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz",
|
||||
@@ -1659,9 +1680,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -3834,27 +3855,6 @@
|
||||
"path-browserify": "^1.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/balanced-match": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
|
||||
"integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/fast-glob": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz",
|
||||
@@ -3884,15 +3884,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/minimatch": {
|
||||
"version": "10.2.4",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz",
|
||||
"integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==",
|
||||
"version": "10.1.1",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz",
|
||||
"integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^5.0.2"
|
||||
"@isaacs/brace-expansion": "^5.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
"node": "20 || >=22"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
@@ -4234,13 +4234,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
|
||||
"version": "9.0.9",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
|
||||
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
|
||||
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^2.0.2"
|
||||
"brace-expansion": "^2.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16 || 14 >=14.17"
|
||||
@@ -4619,9 +4619,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "6.14.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
|
||||
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
|
||||
"version": "6.12.6",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -4653,9 +4653,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv-formats/node_modules/ajv": {
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -6758,12 +6758,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
|
||||
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.1.0"
|
||||
"ip-address": "10.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"version": "4.11.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
|
||||
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
@@ -7556,9 +7556,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
@@ -8831,9 +8831,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
@@ -9699,9 +9699,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
|
||||
"version": "6.14.1",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
|
||||
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"side-channel": "^1.1.0"
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document_set import check_document_sets_are_public
|
||||
from onyx.db.document_set import delete_document_set as db_delete_document_set
|
||||
from onyx.db.document_set import fetch_all_document_sets_for_user
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import insert_document_set
|
||||
@@ -143,10 +142,7 @@ def delete_document_set(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
db_session.refresh(document_set)
|
||||
db_delete_document_set(document_set, db_session)
|
||||
else:
|
||||
if not DISABLE_VECTOR_DB:
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
|
||||
@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
llm = get_default_llm()
|
||||
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -10,8 +11,6 @@ from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.search_settings import get_current_db_embedding_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.indexing.models import EmbeddingModelDetail
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -60,7 +59,7 @@ def test_embedding_configuration(
|
||||
except Exception as e:
|
||||
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
|
||||
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("", response_model=list[EmbeddingModelDetail])
|
||||
@@ -94,9 +93,8 @@ def delete_embedding_provider(
|
||||
embedding_provider is not None
|
||||
and provider_type == embedding_provider.provider_type
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"You can't delete a currently active model",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete a currently active model"
|
||||
)
|
||||
|
||||
remove_embedding_provider(db_session, provider_type=provider_type)
|
||||
|
||||
@@ -11,6 +11,7 @@ from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -37,8 +38,6 @@ from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.llm import validate_persona_ids_exist
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import user_can_access_persona
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
@@ -48,7 +47,6 @@ from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_available_well_known_llms,
|
||||
)
|
||||
@@ -63,8 +61,6 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -76,7 +72,6 @@ from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
|
||||
@@ -191,7 +186,7 @@ def _validate_llm_provider_change(
|
||||
Only enforced in MULTI_TENANT mode.
|
||||
|
||||
Raises:
|
||||
OnyxError: If api_base or custom_config changed without changing API key
|
||||
HTTPException: If api_base or custom_config changed without changing API key
|
||||
"""
|
||||
if not MULTI_TENANT or api_key_changed:
|
||||
return
|
||||
@@ -205,9 +200,9 @@ def _validate_llm_provider_change(
|
||||
)
|
||||
|
||||
if api_base_changed or custom_config_changed:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base and/or custom config cannot be changed without changing the API key",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base and/or custom config cannot be changed without changing the API key",
|
||||
)
|
||||
|
||||
|
||||
@@ -227,7 +222,7 @@ def fetch_llm_provider_options(
|
||||
for well_known_llm in well_known_llms:
|
||||
if well_known_llm.name == provider_name:
|
||||
return well_known_llm
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||
|
||||
|
||||
@admin_router.post("/test")
|
||||
@@ -286,7 +281,7 @@ def test_llm_configuration(
|
||||
error_msg = test_llm(llm)
|
||||
|
||||
if error_msg:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.post("/test/default")
|
||||
@@ -297,11 +292,11 @@ def test_default_provider(
|
||||
llm = get_default_llm()
|
||||
except ValueError:
|
||||
logger.exception("Failed to fetch default LLM Provider")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
|
||||
raise HTTPException(status_code=400, detail="No LLM Provider setup")
|
||||
|
||||
error = test_llm(llm)
|
||||
if error:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
|
||||
raise HTTPException(status_code=400, detail=str(error))
|
||||
|
||||
|
||||
@admin_router.get("/provider")
|
||||
@@ -367,31 +362,35 @@ def put_llm_provider(
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Renaming providers is not currently supported",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -416,9 +415,9 @@ def put_llm_provider(
|
||||
db_session, persona_ids
|
||||
)
|
||||
if missing_personas:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
)
|
||||
# Remove duplicates while preserving order
|
||||
seen: set[int] = set()
|
||||
@@ -445,17 +444,6 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# Before the upsert, check if this provider currently owns the global
|
||||
# CHAT default. The upsert may cascade-delete model_configurations
|
||||
# (and their flow mappings), so we need to remember this beforehand.
|
||||
was_default_provider = False
|
||||
if existing_provider and transitioning_to_auto_mode:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
was_default_provider = (
|
||||
current_default is not None
|
||||
and current_default.llm_provider_id == existing_provider.id
|
||||
)
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
llm_provider_upsert_request=llm_provider_upsert_request,
|
||||
@@ -478,20 +466,6 @@ def put_llm_provider(
|
||||
updated_provider,
|
||||
config,
|
||||
)
|
||||
|
||||
# If this provider was the default before the transition,
|
||||
# restore the default using the recommended model.
|
||||
if was_default_provider:
|
||||
recommended = config.get_default_model(
|
||||
llm_provider_upsert_request.provider
|
||||
)
|
||||
if recommended:
|
||||
update_default_provider(
|
||||
provider_id=updated_provider.id,
|
||||
model_name=recommended.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Refresh result with synced models
|
||||
result = LLMProviderView.from_model(updated_provider)
|
||||
|
||||
@@ -499,7 +473,7 @@ def put_llm_provider(
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
@@ -509,19 +483,19 @@ def delete_llm_provider(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@@ -561,9 +535,9 @@ def get_auto_config(
|
||||
"""
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Failed to fetch configuration from GitHub",
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch configuration from GitHub",
|
||||
)
|
||||
return config.model_dump()
|
||||
|
||||
@@ -720,13 +694,13 @@ def list_llm_providers_for_persona(
|
||||
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
|
||||
# Verify user has access to this persona
|
||||
if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You don't have access to this assistant",
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have access to this assistant",
|
||||
)
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
@@ -880,9 +854,9 @@ def get_bedrock_available_models(
|
||||
try:
|
||||
bedrock = session.client("bedrock")
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
)
|
||||
|
||||
# Build model info dict from foundation models (modelId -> metadata)
|
||||
@@ -1001,14 +975,14 @@ def get_bedrock_available_models(
|
||||
return results
|
||||
|
||||
except (ClientError, NoCredentialsError, BotoCoreError) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to connect to AWS Bedrock: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to connect to AWS Bedrock: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"Unexpected error fetching Bedrock models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Unexpected error fetching Bedrock models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -1020,9 +994,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch Ollama models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch Ollama models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
@@ -1039,9 +1013,9 @@ def get_ollama_available_models(
|
||||
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch Ollama models.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base URL is required to fetch Ollama models.",
|
||||
)
|
||||
|
||||
# NOTE: most people run Ollama locally, so we don't disallow internal URLs
|
||||
@@ -1050,9 +1024,9 @@ def get_ollama_available_models(
|
||||
# with the same response format
|
||||
model_names = _get_ollama_available_model_names(cleaned_api_base)
|
||||
if not model_names:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Ollama server",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your Ollama server",
|
||||
)
|
||||
|
||||
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
|
||||
@@ -1154,9 +1128,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch OpenRouter models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch OpenRouter models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -1177,9 +1151,9 @@ def get_openrouter_available_models(
|
||||
|
||||
data = response_json.get("data", [])
|
||||
if not isinstance(data, list) or len(data) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenRouter endpoint",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your OpenRouter endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenRouterFinalModelResponse] = []
|
||||
@@ -1214,9 +1188,8 @@ def get_openrouter_available_models(
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenRouter",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No compatible models found from OpenRouter"
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
@@ -1246,117 +1219,3 @@ def get_openrouter_available_models(
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/lm-studio/available-models")
|
||||
def get_lm_studio_available_models(
|
||||
request: LMStudioModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LMStudioFinalModelResponse]:
|
||||
"""Fetch available models from an LM Studio server.
|
||||
|
||||
Uses the LM Studio-native /api/v1/models endpoint which exposes
|
||||
rich metadata including capabilities (vision, reasoning),
|
||||
display names, and context lengths.
|
||||
"""
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
# Strip /v1 suffix that users may copy from OpenAI-compatible tool configs;
|
||||
# the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models.
|
||||
cleaned_api_base = cleaned_api_base.removesuffix("/v1")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch LM Studio models.",
|
||||
)
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LM Studio models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your LM Studio server.",
|
||||
)
|
||||
|
||||
results: list[LMStudioFinalModelResponse] = []
|
||||
for item in models:
|
||||
# Filter to LLM-type models only (skip embeddings, etc.)
|
||||
if item.get("type") != "llm":
|
||||
continue
|
||||
|
||||
model_key = item.get("key")
|
||||
if not model_key:
|
||||
continue
|
||||
|
||||
display_name = item.get("display_name") or model_key
|
||||
max_context_length = item.get("max_context_length")
|
||||
capabilities = item.get("capabilities") or {}
|
||||
|
||||
results.append(
|
||||
LMStudioFinalModelResponse(
|
||||
name=model_key,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_context_length,
|
||||
supports_image_input=capabilities.get("vision", False),
|
||||
supports_reasoning=capabilities.get("reasoning", False)
|
||||
or is_reasoning_model(model_key, display_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from LM Studio server.",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -371,22 +371,6 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# LM Studio dynamic models fetch
|
||||
class LMStudioModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
api_key_changed: bool = False
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LMStudioFinalModelResponse(BaseModel):
|
||||
name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B")
|
||||
display_name: str # Human-readable name
|
||||
max_input_tokens: int | None # From LM Studio API or None if unavailable
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@@ -12,7 +12,6 @@ from typing import TypedDict
|
||||
|
||||
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR
|
||||
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -24,7 +23,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -350,19 +348,4 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
elif provider == LlmProviderNames.LM_STUDIO:
|
||||
# LM Studio model IDs can be paths like "publisher/model-name"
|
||||
# or simple names. Use MODEL_PREFIX_TO_VENDOR for matching.
|
||||
|
||||
model_lower = model_name.lower()
|
||||
# Check for slash-separated vendor prefix first
|
||||
if "/" in model_lower:
|
||||
vendor_key = model_lower.split("/")[0]
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
# Fallback to model prefix matching
|
||||
for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items():
|
||||
if model_lower.startswith(prefix):
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title())
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -6,11 +6,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
@@ -18,25 +15,20 @@ from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import create_search_settings
|
||||
from onyx.db.search_settings import delete_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
from onyx.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
|
||||
from onyx.server.manage.models import FullModelVersionResponse
|
||||
from onyx.server.models import IdReturn
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
@@ -49,99 +41,110 @@ def set_new_search_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session), # noqa: ARG001
|
||||
) -> IdReturn:
|
||||
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
|
||||
Gives an error if the same model name is used as the current or secondary index
|
||||
"""
|
||||
Creates a new SearchSettings row and cancels the previous secondary indexing
|
||||
if any exists.
|
||||
"""
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Disallow contextual RAG for cloud deployments.
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider.
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=search_settings_new.provider_type
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
model_name=search_settings_new.contextual_rag_llm_name,
|
||||
db_session=db_session,
|
||||
# TODO(andrei): Re-enable.
|
||||
# NOTE Enable integration external dependency tests in test_search_settings.py
|
||||
# when this is reenabled. They are currently skipped
|
||||
logger.error("Setting new search settings is temporarily disabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Setting new search settings is temporarily disabled.",
|
||||
)
|
||||
# if search_settings_new.index_name:
|
||||
# logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# # Disallow contextual RAG for cloud deployments
|
||||
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Contextual RAG disabled in Onyx Cloud",
|
||||
# )
|
||||
|
||||
if search_settings_new.index_name is None:
|
||||
# We define index name here.
|
||||
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
if (
|
||||
search_settings_new.model_name == search_settings.model_name
|
||||
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
search_values = search_settings_new.model_dump()
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(
|
||||
**search_settings_new.model_dump()
|
||||
)
|
||||
# # Validate cloud provider exists or create new LiteLLM provider
|
||||
# if search_settings_new.provider_type is not None:
|
||||
# cloud_provider = get_embedding_provider_from_provider_type(
|
||||
# db_session, provider_type=search_settings_new.provider_type
|
||||
# )
|
||||
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if cloud_provider is None:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
# )
|
||||
|
||||
if secondary_search_settings:
|
||||
# Cancel any background indexing jobs.
|
||||
expire_index_attempts(
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
# validate_contextual_rag_model(
|
||||
# provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
# model_name=search_settings_new.contextual_rag_llm_name,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# Mark previous model as a past model directly.
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
|
||||
new_search_settings = create_search_settings(
|
||||
search_settings=new_search_settings_request, db_session=db_session
|
||||
)
|
||||
# if search_settings_new.index_name is None:
|
||||
# # We define index name here
|
||||
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
# if (
|
||||
# search_settings_new.model_name == search_settings.model_name
|
||||
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
# ):
|
||||
# index_name += ALT_INDEX_SUFFIX
|
||||
# search_values = search_settings_new.model_dump()
|
||||
# search_values["index_name"] = index_name
|
||||
# new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
# else:
|
||||
# new_search_settings_request = SavedSearchSettings(
|
||||
# **search_settings_new.model_dump()
|
||||
# )
|
||||
|
||||
# Ensure the document indices have the new index immediately.
|
||||
document_indices = get_all_document_indices(search_settings, new_search_settings)
|
||||
for document_index in document_indices:
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=search_settings.embedding_precision,
|
||||
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
)
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Pause index attempts for the currently in-use index to preserve resources.
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=new_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
# if secondary_search_settings:
|
||||
# # Cancel any background indexing jobs
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
# )
|
||||
|
||||
db_session.commit()
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
# # Mark previous model as a past model directly
|
||||
# update_search_settings_status(
|
||||
# search_settings=secondary_search_settings,
|
||||
# new_status=IndexModelStatus.PAST,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# new_search_settings = create_search_settings(
|
||||
# search_settings=new_search_settings_request, db_session=db_session
|
||||
# )
|
||||
|
||||
# # Ensure Vespa has the new index immediately
|
||||
# get_multipass_config(search_settings)
|
||||
# get_multipass_config(new_search_settings)
|
||||
# document_index = get_default_document_index(
|
||||
# search_settings, new_search_settings, db_session
|
||||
# )
|
||||
|
||||
# document_index.ensure_indices_exist(
|
||||
# primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
# primary_embedding_precision=search_settings.embedding_precision,
|
||||
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
# )
|
||||
|
||||
# # Pause index attempts for the currently in use index to preserve resources
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=search_settings.id, db_session=db_session
|
||||
# )
|
||||
# for cc_pair in get_connector_credential_pairs(db_session):
|
||||
# resync_cc_pair(
|
||||
# cc_pair=cc_pair,
|
||||
# search_settings_id=new_search_settings.id,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# db_session.commit()
|
||||
# return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
@@ -60,6 +61,7 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -810,6 +812,18 @@ def fetch_chat_file(
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
original_file_name = file_record.display_name
|
||||
if file_record.file_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
# Check if a converted text file exists for .docx files
|
||||
txt_file_name = docx_to_txt_filename(original_file_name)
|
||||
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
|
||||
txt_file_record = file_store.read_file_record(txt_file_id)
|
||||
if txt_file_record:
|
||||
file_record = txt_file_record
|
||||
file_id = txt_file_id
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
|
||||
@@ -60,11 +60,9 @@ class Settings(BaseModel):
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
|
||||
# or the license is expired (GATED_ACCESS).
|
||||
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
|
||||
ee_features_enabled: bool = False
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
|
||||
@@ -93,8 +93,6 @@ class ToolResponse(BaseModel):
|
||||
# | WebContentResponse
|
||||
# This comes from custom tools, tool result needs to be saved
|
||||
| CustomToolCallSummary
|
||||
# This comes from code interpreter, carries generated files
|
||||
| PythonToolRichResponse
|
||||
# If the rich response is a string, this is what's saved to the tool call in the DB
|
||||
| str
|
||||
| None # If nothing needs to be persisted outside of the string value passed to the LLM
|
||||
@@ -195,12 +193,6 @@ class ChatFile(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PythonToolRichResponse(BaseModel):
|
||||
"""Rich response from the Python tool carrying generated files."""
|
||||
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
|
||||
|
||||
class PythonToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for the Python/Code Interpreter tool."""
|
||||
|
||||
@@ -253,7 +245,6 @@ class ToolCallInfo(BaseModel):
|
||||
tool_call_response: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
generated_files: list[PythonExecutionFile] | None = None
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
|
||||
@@ -111,26 +111,19 @@ def _normalize_text_with_mapping(text: str) -> tuple[str, list[int]]:
|
||||
# Step 1: NFC normalization with position mapping
|
||||
nfc_text = unicodedata.normalize("NFC", text)
|
||||
|
||||
# Map NFD positions → original positions.
|
||||
# NFD only decomposes, so each original char produces 1+ NFD chars.
|
||||
nfd_to_orig: list[int] = []
|
||||
for orig_idx, orig_char in enumerate(original_text):
|
||||
nfd_of_char = unicodedata.normalize("NFD", orig_char)
|
||||
for _ in nfd_of_char:
|
||||
nfd_to_orig.append(orig_idx)
|
||||
|
||||
# Map NFC positions → NFD positions.
|
||||
# Each NFC char, when decomposed, tells us exactly how many NFD
|
||||
# chars it was composed from.
|
||||
# Build mapping from NFC positions to original start positions
|
||||
nfc_to_orig: list[int] = []
|
||||
nfd_idx = 0
|
||||
orig_idx = 0
|
||||
for nfc_char in nfc_text:
|
||||
if nfd_idx < len(nfd_to_orig):
|
||||
nfc_to_orig.append(nfd_to_orig[nfd_idx])
|
||||
nfc_to_orig.append(orig_idx)
|
||||
# Find how many original chars contributed to this NFC char
|
||||
for length in range(1, len(original_text) - orig_idx + 1):
|
||||
substr = original_text[orig_idx : orig_idx + length]
|
||||
if unicodedata.normalize("NFC", substr) == nfc_char:
|
||||
orig_idx += length
|
||||
break
|
||||
else:
|
||||
nfc_to_orig.append(len(original_text) - 1)
|
||||
nfd_of_nfc = unicodedata.normalize("NFD", nfc_char)
|
||||
nfd_idx += len(nfd_of_nfc)
|
||||
orig_idx += 1 # Fallback
|
||||
|
||||
# Work with NFC text from here
|
||||
text = nfc_text
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -15,9 +12,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -86,19 +80,6 @@ class CodeInterpreterClient:
|
||||
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
self._closed = False
|
||||
|
||||
def __enter__(self) -> CodeInterpreterClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self.session.close()
|
||||
self._closed = True
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
@@ -117,32 +98,16 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json().get("status") == "ok"
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
@@ -192,11 +157,8 @@ class CodeInterpreterClient:
|
||||
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
|
||||
return
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
finally:
|
||||
response.close()
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
|
||||
def _parse_sse(
|
||||
self, response: requests.Response
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import LlmPythonExecutionResult
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
@@ -108,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
with CodeInterpreterClient() as client:
|
||||
return client.health(use_cache=True)
|
||||
return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -176,203 +171,194 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
# Create Code Interpreter client — context manager ensures
|
||||
# session.close() is called on every exit path.
|
||||
with CodeInterpreterClient() as client:
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
# Create Code Interpreter client
|
||||
client = CodeInterpreterClient()
|
||||
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=(
|
||||
event.data if event.stream == "stdout" else ""
|
||||
),
|
||||
stderr=(
|
||||
event.data if event.stream == "stderr" else ""
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file "
|
||||
f"{workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated "
|
||||
f"file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged "
|
||||
f"file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
obj=PythonToolDelta(
|
||||
stdout=event.data if event.stream == "stdout" else "",
|
||||
stderr=event.data if event.stream == "stderr" else "",
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=(None if result_event.exit_code == 0 else truncated_stderr),
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=PythonToolRichResponse(
|
||||
generated_files=generated_files,
|
||||
),
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
# Emit error delta
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file {workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
)
|
||||
)
|
||||
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if result_event.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
return ToolResponse(
|
||||
rich_response=None, # No rich response needed for Python tool
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
@@ -56,30 +57,6 @@ def _sanitize_query(query: str) -> str:
|
||||
return " ".join(sanitized.split())
|
||||
|
||||
|
||||
def _normalize_queries_input(raw: Any) -> list[str]:
|
||||
"""Coerce LLM output to a list of sanitized query strings.
|
||||
|
||||
Accepts a bare string or a list (possibly with non-string elements).
|
||||
Sanitizes each query (strip control chars, normalize whitespace) and
|
||||
drops empty or whitespace-only entries.
|
||||
"""
|
||||
if isinstance(raw, str):
|
||||
raw = raw.strip()
|
||||
if not raw:
|
||||
return []
|
||||
raw = [raw]
|
||||
elif not isinstance(raw, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for q in raw:
|
||||
if q is None:
|
||||
continue
|
||||
sanitized = _sanitize_query(str(q))
|
||||
if sanitized:
|
||||
result.append(sanitized)
|
||||
return result
|
||||
|
||||
|
||||
class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
NAME = "web_search"
|
||||
DESCRIPTION = "Search the web for information."
|
||||
@@ -212,7 +189,13 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
f'like: {{"queries": ["your search query here"]}}'
|
||||
),
|
||||
)
|
||||
queries = _normalize_queries_input(llm_kwargs[QUERIES_FIELD])
|
||||
raw_queries = cast(list[str], llm_kwargs[QUERIES_FIELD])
|
||||
|
||||
# Normalize queries:
|
||||
# - remove control characters (null bytes, etc.) that LLMs sometimes produce
|
||||
# - collapse whitespace and strip
|
||||
# - drop empty/whitespace-only queries
|
||||
queries = [sanitized for q in raw_queries if (sanitized := _sanitize_query(q))]
|
||||
if not queries:
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user