mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-16 13:12:41 +00:00
Compare commits
89 Commits
nikg/std-e
...
v3.0.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79b615db46 | ||
|
|
98756bccd4 | ||
|
|
418f84ccdf | ||
|
|
d37756a884 | ||
|
|
9cdc92441b | ||
|
|
b8ed30644a | ||
|
|
d7d19e5a28 | ||
|
|
948650829d | ||
|
|
b6e689be0f | ||
|
|
85877408c8 | ||
|
|
c00df75c79 | ||
|
|
6352c9a09e | ||
|
|
3065f70d7d | ||
|
|
4befbc49dc | ||
|
|
ae9679e8c4 | ||
|
|
ea0ddee5c8 | ||
|
|
2826405dd2 | ||
|
|
8485bf4368 | ||
|
|
7bb52b0839 | ||
|
|
85a54c01f1 | ||
|
|
e4577bd564 | ||
|
|
f150a7b940 | ||
|
|
f1df36e306 | ||
|
|
1611604269 | ||
|
|
c2a71091dc | ||
|
|
cc008699e5 | ||
|
|
48802618db | ||
|
|
6917953b86 | ||
|
|
e7cf027f8a | ||
|
|
41fb1480bb | ||
|
|
bdc2bfdcee | ||
|
|
8816d52b27 | ||
|
|
6590f1d7ba | ||
|
|
c527f75557 | ||
|
|
472d1788a7 | ||
|
|
99e95f8205 | ||
|
|
e618bf8385 | ||
|
|
f4dcd130ba | ||
|
|
910718deaa | ||
|
|
1a7ca93b93 | ||
|
|
a615a920cb | ||
|
|
29d8b310b5 | ||
|
|
d1409ccafa | ||
|
|
e41bad9103 | ||
|
|
661dc831dc | ||
|
|
19016dd35a | ||
|
|
127b2dcc80 | ||
|
|
b015a37cea | ||
|
|
b45277a8b0 | ||
|
|
893e8da79a | ||
|
|
a51f0d7cb2 | ||
|
|
c826d0469e | ||
|
|
0f6ae6f69c | ||
|
|
d0836e2603 | ||
|
|
bda03bafca | ||
|
|
376adff94a | ||
|
|
d2d4b89286 | ||
|
|
dde7a18bb7 | ||
|
|
3f004cf02f | ||
|
|
ae893079c3 | ||
|
|
189c07a913 | ||
|
|
2b82743bf5 | ||
|
|
ba2a5a60e1 | ||
|
|
5888f9d69f | ||
|
|
23b3a0a6ae | ||
|
|
eced88fa7a | ||
|
|
f59aaa902d | ||
|
|
57349bdbd1 | ||
|
|
192639a801 | ||
|
|
c10ffbb464 | ||
|
|
091f41fd1f | ||
|
|
45d77be4eb | ||
|
|
413fa85134 | ||
|
|
108cde4f55 | ||
|
|
f88ce32bd4 | ||
|
|
911f3439ea | ||
|
|
b02590d2b2 | ||
|
|
2d75b4b1f8 | ||
|
|
7e3f7d01c2 | ||
|
|
9d6ce26ea3 | ||
|
|
41713d42a2 | ||
|
|
8afc283410 | ||
|
|
b5c873077e | ||
|
|
20a4dd32eb | ||
|
|
fde0d44bc1 | ||
|
|
8fd91b6e83 | ||
|
|
8247fdd45b | ||
|
|
8c5859ba4d | ||
|
|
62ef6f59bb |
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,186 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
96
.github/workflows/deployment.yml
vendored
96
.github/workflows/deployment.yml
vendored
@@ -182,9 +182,53 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
# Create GitHub release first, before desktop builds start.
|
||||
# This ensures all desktop matrix jobs upload to the same release instead of
|
||||
# racing to create duplicate releases.
|
||||
create-release:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
release-id: ${{ steps.create-release.outputs.id }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Determine release tag
|
||||
id: release-tag
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
body: "See the assets to download this version and install."
|
||||
draft: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
build-desktop:
|
||||
needs:
|
||||
- determine-builds
|
||||
- create-release
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
@@ -208,12 +252,12 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
# NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -353,11 +397,9 @@ jobs:
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
# Use the release created by the create-release job to avoid race conditions
|
||||
# when multiple matrix jobs try to create/update the same release simultaneously
|
||||
releaseId: ${{ needs.create-release.outputs.release-id }}
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
@@ -384,7 +426,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -458,7 +500,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -527,7 +569,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -597,7 +639,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -679,7 +721,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -756,7 +798,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -823,7 +865,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -896,7 +938,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -964,7 +1006,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1034,7 +1076,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1107,7 +1149,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1176,7 +1218,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1246,7 +1288,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1326,7 +1368,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1400,7 +1442,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1465,7 +1507,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1520,7 +1562,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1580,7 +1622,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1637,7 +1679,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -15,7 +15,8 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
@@ -6,11 +6,13 @@ on:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Golang Tests
|
||||
concurrency:
|
||||
group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
GO_VERSION: "1.26"
|
||||
|
||||
jobs:
|
||||
detect-modules:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
modules: ${{ steps.set-modules.outputs.modules }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
with:
|
||||
persist-credentials: false
|
||||
- id: set-modules
|
||||
run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
|
||||
|
||||
golang:
|
||||
needs: detect-modules
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
cache-dependency-path: "**/go.sum"
|
||||
|
||||
- run: go mod tidy
|
||||
working-directory: ${{ matrix.modules }}
|
||||
- run: git diff --exit-code go.mod go.sum
|
||||
working-directory: ${{ matrix.modules }}
|
||||
|
||||
- run: go test ./...
|
||||
working-directory: ${{ matrix.modules }}
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -316,6 +316,7 @@ jobs:
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -418,6 +419,7 @@ jobs:
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
@@ -637,6 +639,7 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -691,6 +694,7 @@ jobs:
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
|
||||
5
.github/workflows/pr-playwright-tests.yml
vendored
5
.github/workflows/pr-playwright-tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
# TODO: Remove this if we enable merge-queues for release branches.
|
||||
branches:
|
||||
- "release/**"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -461,7 +464,7 @@ jobs:
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,9 +38,9 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.2.21'
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
|
||||
214
.github/workflows/release-cli.yml
vendored
Normal file
214
.github/workflows/release-cli.yml
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
name: Release CLI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "cli/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-cli
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -22,12 +22,10 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
@@ -48,6 +48,10 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN:
|
||||
description: "AWS role ARN for OIDC auth"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -73,7 +77,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -116,7 +120,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -158,7 +162,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -264,7 +268,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
6
.github/workflows/sandbox-deployment.yml
vendored
6
.github/workflows/sandbox-deployment.yml
vendored
@@ -110,7 +110,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -180,7 +180,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -244,7 +244,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -119,10 +119,11 @@ repos:
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
language_version: "1.26.0"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
|
||||
@@ -104,6 +104,10 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
- Never enqueue a task without an expiration. Always supply `expires=` when
|
||||
sending tasks, either from the beat schedule or directly from another task. It
|
||||
should never be acceptable to submit code which enqueues tasks without an
|
||||
expiration, as doing so can lead to unbounded task queue growth.
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
@@ -540,6 +544,8 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
@@ -592,7 +598,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
|
||||
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -117,6 +123,26 @@ def get_external_access_for_raw_gdrive_file(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} "
|
||||
f"in group {group_name} has no visible email address"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
|
||||
return group_member_emails
|
||||
return emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -130,12 +127,26 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
|
||||
@@ -246,7 +246,11 @@ async def get_billing_information(
|
||||
)
|
||||
except OnyxError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
if e.status_code in (
|
||||
OnyxErrorCode.BAD_GATEWAY.status_code,
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
|
||||
):
|
||||
_open_billing_circuit()
|
||||
raise
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -125,10 +126,16 @@ def _seed_llms(
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
seeded_providers: list[LLMProviderView] = []
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
try:
|
||||
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Failed to upsert LLM provider '%s' during seeding: %s",
|
||||
llm_upsert_request.name,
|
||||
e,
|
||||
)
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
|
||||
@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -71,6 +69,8 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -51,6 +50,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
@@ -84,28 +84,6 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -686,12 +664,7 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
final_tools = []
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
@@ -1008,6 +981,10 @@ def run_llm_loop(
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
|
||||
saved_response = json.dumps(
|
||||
tool_response.rich_response.model_dump()
|
||||
)
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -166,15 +167,6 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
|
||||
- NULL bytes (\x00): Not allowed in PostgreSQL text types
|
||||
- UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
|
||||
"""
|
||||
return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))
|
||||
|
||||
|
||||
def _try_parse_json_string(value: Any) -> Any:
|
||||
"""Attempt to parse a JSON string value into its Python equivalent.
|
||||
|
||||
@@ -222,9 +214,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_args, dict):
|
||||
# Parse any string values that look like JSON arrays/objects
|
||||
return {
|
||||
k: _try_parse_json_string(
|
||||
_sanitize_llm_output(v) if isinstance(v, str) else v
|
||||
)
|
||||
k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
|
||||
for k, v in raw_args.items()
|
||||
}
|
||||
|
||||
@@ -232,7 +222,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
# Sanitize before parsing to remove NULL bytes and surrogates
|
||||
raw_args = _sanitize_llm_output(raw_args)
|
||||
raw_args = sanitize_string(raw_args)
|
||||
|
||||
try:
|
||||
parsed1: Any = json.loads(raw_args)
|
||||
@@ -545,12 +535,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
return sanitize_string(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
value = sanitize_string(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
@@ -569,6 +559,7 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
arguments = sanitize_string(arguments)
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -201,8 +202,13 @@ def save_chat_turn(
|
||||
pre_answer_processing_time: Duration of processing before answer starts (in seconds)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
sanitized_message_text = (
|
||||
sanitize_string(message_text) if message_text else message_text
|
||||
)
|
||||
assistant_message.message = sanitized_message_text
|
||||
assistant_message.reasoning_tokens = (
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
)
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Use pre-answer processing time (captured when MESSAGE_START was emitted)
|
||||
@@ -212,8 +218,10 @@ def save_chat_turn(
|
||||
# Calculate token count using default tokenizer, when storing, this should not use the LLM
|
||||
# specific one so we use a system default tokenizer here.
|
||||
default_tokenizer = get_tokenizer(None, None)
|
||||
if message_text:
|
||||
assistant_message.token_count = len(default_tokenizer.encode(message_text))
|
||||
if sanitized_message_text:
|
||||
assistant_message.token_count = len(
|
||||
default_tokenizer.encode(sanitized_message_text)
|
||||
)
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
@@ -328,8 +336,10 @@ def save_chat_turn(
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if message_text:
|
||||
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
|
||||
if sanitized_message_text:
|
||||
referenced = _extract_referenced_file_descriptors(
|
||||
tool_calls, sanitized_message_text
|
||||
)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
@@ -288,8 +288,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -204,7 +205,7 @@ def _manage_async_retrieval(
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
async def _async_fetch() -> AsyncGenerator[Document, None]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
@@ -227,22 +228,23 @@ def _manage_async_retrieval(
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
async_gen = _async_fetch()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
doc = loop.run_until_complete(anext(async_gen))
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
# Must close the async generator before the loop so the Discord
|
||||
# client's `async with` block can await its shutdown coroutine.
|
||||
# The nested try/finally ensures the loop always closes even if
|
||||
# aclose() raises (same pattern as cursor.close() before conn.close()).
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
@@ -1722,6 +1722,7 @@ class GoogleDriveConnector(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
|
||||
|
||||
@@ -476,6 +476,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -484,6 +485,8 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
When False (default), leave unprefixed (for permission sync path
|
||||
where upsert_document_external_perms handles prefixing).
|
||||
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
|
||||
files), fall back to granting access to this user.
|
||||
"""
|
||||
external_access_fn = cast(
|
||||
Callable[
|
||||
@@ -492,6 +495,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
str,
|
||||
GoogleDriveService | None,
|
||||
GoogleDriveService,
|
||||
str,
|
||||
bool,
|
||||
],
|
||||
ExternalAccess,
|
||||
@@ -507,6 +511,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain,
|
||||
retriever_drive_service,
|
||||
admin_drive_service,
|
||||
fallback_user_email,
|
||||
add_prefix,
|
||||
)
|
||||
|
||||
@@ -672,6 +677,7 @@ def _convert_drive_item_to_document(
|
||||
creds, user_email=permission_sync_context.primary_admin_email
|
||||
),
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
@@ -753,6 +759,7 @@ def build_slim_document(
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_email: str,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
@@ -774,6 +781,7 @@ def build_slim_document(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -675,58 +676,43 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
if value and not sanitized:
|
||||
logger.warning("Sanitization removed all characters from string")
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
|
||||
"""Remove NUL (0x00) characters from all strings in a list."""
|
||||
return [_sanitize_for_postgres(v) for v in values]
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> DBSearchDoc:
|
||||
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
|
||||
db_search_doc = DBSearchDoc(
|
||||
document_id=_sanitize_for_postgres(server_search_doc.document_id),
|
||||
document_id=sanitize_string(server_search_doc.document_id),
|
||||
chunk_ind=server_search_doc.chunk_ind,
|
||||
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
|
||||
semantic_id=sanitize_string(server_search_doc.semantic_identifier),
|
||||
link=(
|
||||
_sanitize_for_postgres(server_search_doc.link)
|
||||
sanitize_string(server_search_doc.link)
|
||||
if server_search_doc.link is not None
|
||||
else None
|
||||
),
|
||||
blurb=_sanitize_for_postgres(server_search_doc.blurb),
|
||||
blurb=sanitize_string(server_search_doc.blurb),
|
||||
source_type=server_search_doc.source_type,
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=(
|
||||
_sanitize_for_postgres(server_search_doc.relevance_explanation)
|
||||
sanitize_string(server_search_doc.relevance_explanation)
|
||||
if server_search_doc.relevance_explanation is not None
|
||||
else None
|
||||
),
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=_sanitize_list_for_postgres(
|
||||
server_search_doc.match_highlights
|
||||
),
|
||||
match_highlights=[
|
||||
sanitize_string(h) for h in server_search_doc.match_highlights
|
||||
],
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.primary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.primary_owners]
|
||||
if server_search_doc.primary_owners is not None
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.secondary_owners]
|
||||
if server_search_doc.secondary_owners is not None
|
||||
else None
|
||||
),
|
||||
|
||||
@@ -25,8 +25,12 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
@@ -267,10 +271,35 @@ def upsert_llm_provider(
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of requested visibility by model name
|
||||
requested_visibility = {
|
||||
mc.name: mc.is_visible
|
||||
for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Delete removed models
|
||||
removed_ids = [
|
||||
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
|
||||
]
|
||||
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
# Prevent removing and hiding the default model
|
||||
if default_model:
|
||||
for name, mc in existing_by_name.items():
|
||||
if mc.id == default_model.id:
|
||||
if default_model.id in removed_ids:
|
||||
raise ValueError(
|
||||
f"Cannot remove the default model '{name}'. "
|
||||
"Please change the default model before removing."
|
||||
)
|
||||
if not requested_visibility.get(name, True):
|
||||
raise ValueError(
|
||||
f"Cannot hide the default model '{name}'. "
|
||||
"Please change the default model before hiding."
|
||||
)
|
||||
break
|
||||
|
||||
if removed_ids:
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.id.in_(removed_ids)
|
||||
@@ -341,9 +370,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
models: list[SyncModelEntry],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -351,7 +380,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -365,21 +394,20 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
if model.name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.get("supports_image_input", False):
|
||||
if model.supports_image_input:
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_name,
|
||||
model_name=model.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
@@ -535,7 +563,6 @@ def fetch_default_model(
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
@@ -811,6 +838,29 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default.
|
||||
# We flush (but don't commit) so that _update_default_model can see the new
|
||||
# model rows, then commit everything atomically to avoid a window where the
|
||||
# old default is invisible but still pointed-to.
|
||||
db_session.flush()
|
||||
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
|
||||
if (
|
||||
current_default
|
||||
and current_default.llm_provider_id == provider.id
|
||||
and current_default.name != recommended_default.name
|
||||
):
|
||||
_update_default_model__no_commit(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
@@ -942,7 +992,7 @@ def update_model_configuration__no_commit(
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
def _update_default_model__no_commit(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
@@ -980,6 +1030,14 @@ def _update_default_model(
|
||||
new_default.is_default = True
|
||||
model_config.is_visible = True
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
_update_default_model__no_commit(db_session, provider_id, model, flow_type)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -13,12 +13,15 @@ from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import MCPServerStatus
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.server.features.tool.models import Header
|
||||
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_json_like
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -159,10 +162,26 @@ def update_tool(
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
old_oauth_config_id = tool.oauth_config_id
|
||||
if not isinstance(oauth_config_id, UnsetType):
|
||||
tool.oauth_config_id = oauth_config_id
|
||||
db_session.commit()
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if the oauth_config_id was changed
|
||||
if (
|
||||
old_oauth_config_id is not None
|
||||
and not isinstance(oauth_config_id, UnsetType)
|
||||
and old_oauth_config_id != oauth_config_id
|
||||
):
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
|
||||
db_session.commit()
|
||||
return tool
|
||||
|
||||
|
||||
@@ -171,8 +190,21 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
oauth_config_id = tool.oauth_config_id
|
||||
|
||||
db_session.delete(tool)
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if no other tools reference it
|
||||
if oauth_config_id is not None:
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_builtin_tool(
|
||||
@@ -256,11 +288,13 @@ def create_tool_call_no_commit(
|
||||
tab_index=tab_index,
|
||||
tool_id=tool_id,
|
||||
tool_call_id=tool_call_id,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
tool_call_arguments=tool_call_arguments,
|
||||
tool_call_response=tool_call_response,
|
||||
reasoning_tokens=(
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
),
|
||||
tool_call_arguments=sanitize_json_like(tool_call_arguments),
|
||||
tool_call_response=sanitize_json_like(tool_call_response),
|
||||
tool_call_tokens=tool_call_tokens,
|
||||
generated_images=generated_images,
|
||||
generated_images=sanitize_json_like(generated_images),
|
||||
)
|
||||
|
||||
db_session.add(tool_call)
|
||||
|
||||
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -61,6 +61,25 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class IndexInfo(BaseModel):
|
||||
"""
|
||||
Represents information about an OpenSearch index.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
name: str
|
||||
health: str
|
||||
status: str
|
||||
num_primary_shards: str
|
||||
num_replica_shards: str
|
||||
docs_count: str
|
||||
docs_deleted: str
|
||||
created_at: str
|
||||
total_size: str
|
||||
primary_shards_size: str
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
@@ -159,8 +178,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -173,8 +192,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -198,6 +217,34 @@ class OpenSearchClient(AbstractContextManager):
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def list_indices_with_info(self) -> list[IndexInfo]:
|
||||
"""
|
||||
Lists the indices in the OpenSearch cluster with information about each
|
||||
index.
|
||||
|
||||
Returns:
|
||||
A list of IndexInfo objects for each index.
|
||||
"""
|
||||
response = self._client.cat.indices(format="json")
|
||||
indices: list[IndexInfo] = []
|
||||
for raw_index_info in response:
|
||||
indices.append(
|
||||
IndexInfo(
|
||||
name=raw_index_info.get("index", ""),
|
||||
health=raw_index_info.get("health", ""),
|
||||
status=raw_index_info.get("status", ""),
|
||||
num_primary_shards=raw_index_info.get("pri", ""),
|
||||
num_replica_shards=raw_index_info.get("rep", ""),
|
||||
docs_count=raw_index_info.get("docs.count", ""),
|
||||
docs_deleted=raw_index_info.get("docs.deleted", ""),
|
||||
created_at=raw_index_info.get("creation.date.string", ""),
|
||||
total_size=raw_index_info.get("store.size", ""),
|
||||
primary_shards_size=raw_index_info.get("pri.store.size", ""),
|
||||
)
|
||||
)
|
||||
return indices
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
@@ -739,7 +739,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
The number of chunks successfully deleted.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id=document_id,
|
||||
@@ -775,7 +776,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
specified documents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
for update_request in update_requests:
|
||||
properties_to_update: dict[str, Any] = dict()
|
||||
@@ -831,9 +833,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here.
|
||||
# TODO(andrei): Fix the aforementioned race condition.
|
||||
raise ChunkCountNotFoundError(
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
|
||||
"application used to permit this but is not a supported state for a document when using OpenSearch. "
|
||||
"The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. "
|
||||
"Older versions of the application used to permit this but is not a "
|
||||
"supported state for a document when using OpenSearch. The document was "
|
||||
"likely just added to the indexing pipeline and the chunk count will be "
|
||||
"updated shortly."
|
||||
)
|
||||
if doc_chunk_count == 0:
|
||||
raise ValueError(
|
||||
@@ -865,7 +869,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
@@ -912,7 +917,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# TODO(andrei): This could be better, the caller should just make this
|
||||
# decision when passing in the query param. See the above comment in the
|
||||
@@ -932,8 +938,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid
|
||||
# search from a theoretical standpoint. Empirically on a small dataset
|
||||
# of up to 10K docs, it's not very different. Likely more impactful at
|
||||
# scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
@@ -960,7 +968,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
dirty: bool | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_random_search_query(
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -990,7 +999,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
complete.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
|
||||
@@ -243,7 +243,8 @@ class DocumentChunk(BaseModel):
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@@ -284,19 +285,22 @@ class DocumentChunk(BaseModel):
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
|
||||
f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
|
||||
"current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but "
|
||||
"multi-tenant mode is not enabled. This is unexpected because in single-tenant "
|
||||
"mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
@@ -352,8 +356,10 @@ class DocumentSchema:
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
# Language analyzer (e.g. english) stems at index and search time for variant matching.
|
||||
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
|
||||
# Language analyzer (e.g. english) stems at index and search
|
||||
# time for variant matching. Configure via
|
||||
# OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
|
||||
# after a change.
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
|
||||
@@ -698,41 +698,6 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -758,41 +723,53 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -37,30 +34,22 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -73,16 +62,12 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -104,8 +89,7 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -114,16 +98,14 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -135,8 +117,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -147,8 +129,7 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -160,73 +141,94 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# filter_str += _build_kg_filter(
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# )
|
||||
# ))
|
||||
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# )
|
||||
# ))
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ class OnyxErrorCode(Enum):
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
{"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
If no message is supplied, the error code itself is used as the detail.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
"detail": message or self.code,
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
``{"error_code": "...", "detail": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -37,21 +37,22 @@ class OnyxError(Exception):
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
detail: Human-readable detail (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
detail: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_detail = detail or error_code.code
|
||||
super().__init__(resolved_detail)
|
||||
self.error_code = error_code
|
||||
self.message = message or error_code.code
|
||||
self.detail = resolved_detail
|
||||
self._status_code_override = status_code_override
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
@@ -72,11 +73,11 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
)
|
||||
|
||||
@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -65,6 +64,7 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ class LlmProviderNames(str, Enum):
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE = "azure"
|
||||
OLLAMA_CHAT = "ollama_chat"
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
|
||||
@@ -41,6 +42,8 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
]
|
||||
|
||||
|
||||
@@ -56,6 +59,8 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.AZURE: "Azure",
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -103,8 +108,10 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -20,7 +20,9 @@ from onyx.llm.multi_llm import LitellmLLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
@@ -32,14 +34,18 @@ logger = setup_logger()
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
|
||||
raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
|
||||
api_key = raw.strip() if raw else None
|
||||
if not api_key:
|
||||
return {}
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
return {"Authorization": api_key}
|
||||
return {
|
||||
"Authorization": (
|
||||
api_key
|
||||
if api_key.lower().startswith("bearer ")
|
||||
else f"Bearer {api_key}"
|
||||
)
|
||||
}
|
||||
|
||||
# Passing these will put Onyx on the OpenRouter leaderboard
|
||||
elif provider == LlmProviderNames.OPENROUTER:
|
||||
|
||||
@@ -1512,6 +1512,10 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1526,6 +1530,10 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -2516,6 +2524,10 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.4": {
|
||||
"display_name": "GPT-5.4",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
@@ -92,6 +93,98 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
return [prompt.model_dump(exclude_none=True)]
|
||||
|
||||
|
||||
def _normalize_content(raw: Any) -> str:
|
||||
"""Normalize a message content field to a plain string.
|
||||
|
||||
Content can be a string, None, or a list of content-block dicts
|
||||
(e.g. [{"type": "text", "text": "..."}]).
|
||||
"""
|
||||
if raw is None:
|
||||
return ""
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, list):
|
||||
return "\n".join(
|
||||
block.get("text", "") if isinstance(block, dict) else str(block)
|
||||
for block in raw
|
||||
)
|
||||
return str(raw)
|
||||
|
||||
|
||||
def _strip_tool_content_from_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tool-related messages to plain text.
|
||||
|
||||
Bedrock's Converse API requires toolConfig when messages contain
|
||||
toolUse/toolResult content blocks. When no tools are provided for the
|
||||
current request, we must convert any tool-related history into plain text
|
||||
to avoid the "toolConfig field must be defined" error.
|
||||
|
||||
This is the same approach used by _OllamaHistoryMessageFormatter.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert structured tool calls to text representation
|
||||
tool_call_lines = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
args = func.get("arguments", "{}")
|
||||
tc_id = tc.get("id", "")
|
||||
tool_call_lines.append(
|
||||
f"[Tool Call] name={name} id={tc_id} args={args}"
|
||||
)
|
||||
|
||||
existing_content = _normalize_content(msg.get("content"))
|
||||
parts = (
|
||||
[existing_content] + tool_call_lines
|
||||
if existing_content
|
||||
else tool_call_lines
|
||||
)
|
||||
new_msg = {
|
||||
"role": "assistant",
|
||||
"content": "\n".join(parts),
|
||||
}
|
||||
result.append(new_msg)
|
||||
|
||||
elif role == "tool":
|
||||
# Convert tool response to user message with text content
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = _normalize_content(msg.get("content"))
|
||||
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
|
||||
# Merge into previous user message if it is also a converted
|
||||
# tool result to avoid consecutive user messages (Bedrock requires
|
||||
# strict user/assistant alternation).
|
||||
if (
|
||||
result
|
||||
and result[-1]["role"] == "user"
|
||||
and "[Tool Result]" in result[-1].get("content", "")
|
||||
):
|
||||
result[-1]["content"] += "\n\n" + tool_result_text
|
||||
else:
|
||||
result.append({"role": "user", "content": tool_result_text})
|
||||
|
||||
else:
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -157,6 +250,9 @@ class LitellmLLM(LLM):
|
||||
elif model_provider == LlmProviderNames.OLLAMA_CHAT:
|
||||
if k == OLLAMA_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.LM_STUDIO:
|
||||
if k == LM_STUDIO_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.BEDROCK:
|
||||
if k == AWS_REGION_NAME_KWARG:
|
||||
model_kwargs[k] = v
|
||||
@@ -173,6 +269,19 @@ class LitellmLLM(LLM):
|
||||
elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT:
|
||||
model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v
|
||||
|
||||
# LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided,
|
||||
# which LM Studio rejects. Ensure we always pass an explicit key (or empty
|
||||
# string) to prevent LiteLLM from injecting its fake default.
|
||||
if model_provider == LlmProviderNames.LM_STUDIO:
|
||||
model_kwargs.setdefault("api_key", "")
|
||||
|
||||
# Users provide the server root (e.g. http://localhost:1234) but LiteLLM
|
||||
# needs /v1 for OpenAI-compatible calls.
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# Default vertex_location to "global" if not provided for Vertex AI
|
||||
# Latest gemini models are only available through the global region
|
||||
if (
|
||||
@@ -404,13 +513,30 @@ class LitellmLLM(LLM):
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
messages = _prompt_to_dicts(prompt)
|
||||
|
||||
# Bedrock's Converse API requires toolConfig when messages
|
||||
# contain toolUse/toolResult content blocks. When no tools are
|
||||
# provided for this request but the history contains tool
|
||||
# content from previous turns, strip it to plain text.
|
||||
is_bedrock = self._model_provider in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}
|
||||
if (
|
||||
is_bedrock
|
||||
and not tools
|
||||
and _messages_contain_tool_content(messages)
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
|
||||
@@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None:
|
||||
error_msg = None
|
||||
for _ in range(2):
|
||||
try:
|
||||
llm.invoke(UserMessage(content="Do not respond"))
|
||||
llm.invoke(UserMessage(content="Do not respond"), max_tokens=50)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
@@ -1,52 +1,29 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY,
|
||||
}
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -54,13 +31,6 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -15,6 +15,8 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
@@ -44,7 +46,9 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
}
|
||||
|
||||
|
||||
@@ -323,11 +327,13 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
_ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)",
|
||||
OLLAMA_PROVIDER_NAME: "Ollama",
|
||||
LM_STUDIO_PROVIDER_NAME: "LM Studio",
|
||||
ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)",
|
||||
AZURE_PROVIDER_NAME: "Azure OpenAI",
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-02-05T00:00:00Z",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.2" },
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
"additional_visible_models": [
|
||||
{ "name": "gpt-5-mini" },
|
||||
{ "name": "gpt-4.1" }
|
||||
{ "name": "gpt-5.4" },
|
||||
{ "name": "gpt-5.2" }
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
@@ -16,6 +16,10 @@
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -26,6 +27,14 @@ async def search_indexed_documents(
|
||||
Use this tool for information that is not public knowledge and specific to the user,
|
||||
their team, their work, or their organization/company.
|
||||
|
||||
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
|
||||
on every call, consuming tokens and adding latency.
|
||||
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
|
||||
but this should still be sufficient for most use cases. CE mode functionality should be swapped
|
||||
when a dedicated CE search endpoint is implemented.
|
||||
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
@@ -111,48 +120,73 @@ async def search_indexed_documents(
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
# Build the search request using the new SendSearchQueryRequest format
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
|
||||
# Call the API server using the new send-search-message route
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
|
||||
endpoint,
|
||||
json=search_request,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get("error"):
|
||||
if result.get(error_key):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get("error"),
|
||||
"error": result.get(error_key),
|
||||
}
|
||||
|
||||
# Return simplified format for MCP clients
|
||||
fields_to_return = [
|
||||
"semantic_identifier",
|
||||
"content",
|
||||
"source_type",
|
||||
"link",
|
||||
"score",
|
||||
]
|
||||
documents = [
|
||||
{key: doc.get(key) for key in fields_to_return}
|
||||
for doc in result.get("search_docs", [])
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
# backend retrieves fewer documents than requested.
|
||||
documents = documents[:limit]
|
||||
|
||||
logger.info(
|
||||
f"Onyx MCP Server: Internal search returned {len(documents)} results"
|
||||
)
|
||||
@@ -160,7 +194,6 @@ async def search_indexed_documents(
|
||||
"documents": documents,
|
||||
"total_results": len(documents),
|
||||
"query": query,
|
||||
"executed_queries": result.get("all_executed_queries", [query]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
|
||||
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.9",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz",
|
||||
"integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==",
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1573,27 +1573,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/balanced-match": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
|
||||
"integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/brace-expansion": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
|
||||
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@isaacs/balanced-match": "^4.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/gen-mapping": {
|
||||
"version": "0.3.13",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz",
|
||||
@@ -1680,9 +1659,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -3855,6 +3834,27 @@
|
||||
"path-browserify": "^1.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/balanced-match": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
|
||||
"integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/fast-glob": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz",
|
||||
@@ -3884,15 +3884,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/minimatch": {
|
||||
"version": "10.1.1",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz",
|
||||
"integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==",
|
||||
"version": "10.2.4",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz",
|
||||
"integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"dependencies": {
|
||||
"@isaacs/brace-expansion": "^5.0.0"
|
||||
"brace-expansion": "^5.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
"node": "18 || 20 || >=22"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
@@ -4234,13 +4234,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
|
||||
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
|
||||
"version": "9.0.9",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
|
||||
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^2.0.1"
|
||||
"brace-expansion": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16 || 14 >=14.17"
|
||||
@@ -4619,9 +4619,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "6.12.6",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"version": "6.14.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
|
||||
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -4653,9 +4653,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv-formats/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -6758,12 +6758,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
|
||||
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.0.1"
|
||||
"ip-address": "10.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
@@ -7556,9 +7556,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
@@ -8831,9 +8831,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
@@ -9699,9 +9699,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.1",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
|
||||
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"side-channel": "^1.1.0"
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_available_well_known_llms,
|
||||
)
|
||||
@@ -57,22 +58,30 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
|
||||
@@ -93,6 +102,34 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -441,6 +478,18 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# When transitioning to auto mode, preserve existing model configurations
|
||||
# so the upsert doesn't try to delete them (which would trip the default
|
||||
# model protection guard). sync_auto_mode_models will handle the model
|
||||
# lifecycle afterward — adding new models, hiding removed ones, and
|
||||
# updating the default. This is safe even if sync fails: the provider
|
||||
# keeps its old models and default rather than losing them.
|
||||
if transitioning_to_auto_mode and existing_provider:
|
||||
llm_provider_upsert_request.model_configurations = [
|
||||
ModelConfigurationUpsertRequest.from_model(mc)
|
||||
for mc in existing_provider.model_configurations
|
||||
]
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
llm_provider_upsert_request=llm_provider_upsert_request,
|
||||
@@ -453,7 +502,6 @@ def put_llm_provider(
|
||||
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
@@ -947,27 +995,20 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1085,27 +1126,20 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1194,26 +1228,226 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/lm-studio/available-models")
|
||||
def get_lm_studio_available_models(
|
||||
request: LMStudioModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LMStudioFinalModelResponse]:
|
||||
"""Fetch available models from an LM Studio server.
|
||||
|
||||
Uses the LM Studio-native /api/v1/models endpoint which exposes
|
||||
rich metadata including capabilities (vision, reasoning),
|
||||
display names, and context lengths.
|
||||
"""
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
# Strip /v1 suffix that users may copy from OpenAI-compatible tool configs;
|
||||
# the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models.
|
||||
cleaned_api_base = cleaned_api_base.removesuffix("/v1")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch LM Studio models.",
|
||||
)
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LM Studio models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your LM Studio server.",
|
||||
)
|
||||
|
||||
results: list[LMStudioFinalModelResponse] = []
|
||||
for item in models:
|
||||
# Filter to LLM-type models only (skip embeddings, etc.)
|
||||
if item.get("type") != "llm":
|
||||
continue
|
||||
|
||||
model_key = item.get("key")
|
||||
if not model_key:
|
||||
continue
|
||||
|
||||
display_name = item.get("display_name") or model_key
|
||||
max_context_length = item.get("max_context_length")
|
||||
capabilities = item.get("capabilities") or {}
|
||||
|
||||
results.append(
|
||||
LMStudioFinalModelResponse(
|
||||
name=model_key,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_context_length,
|
||||
supports_image_input=capabilities.get("vision", False),
|
||||
supports_reasoning=capabilities.get("reasoning", False)
|
||||
or is_reasoning_model(model_key, display_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from LM Studio server.",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -371,6 +371,22 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# LM Studio dynamic models fetch
|
||||
class LMStudioModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
api_key_changed: bool = False
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LMStudioFinalModelResponse(BaseModel):
|
||||
name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B")
|
||||
display_name: str # Human-readable name
|
||||
max_input_tokens: int | None # From LM Studio API or None if unavailable
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
@@ -404,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import TypedDict
|
||||
|
||||
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR
|
||||
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -23,6 +24,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -348,4 +350,19 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
elif provider == LlmProviderNames.LM_STUDIO:
|
||||
# LM Studio model IDs can be paths like "publisher/model-name"
|
||||
# or simple names. Use MODEL_PREFIX_TO_VENDOR for matching.
|
||||
|
||||
model_lower = model_name.lower()
|
||||
# Check for slash-separated vendor prefix first
|
||||
if "/" in model_lower:
|
||||
vendor_key = model_lower.split("/")[0]
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
# Fallback to model prefix matching
|
||||
for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items():
|
||||
if model_lower.startswith(prefix):
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title())
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.citation_utils import extract_citation_order_from_text
|
||||
@@ -20,7 +22,9 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderResult
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderStart
|
||||
@@ -180,24 +184,37 @@ def create_custom_tool_packets(
|
||||
tab_index: int = 0,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
error: CustomToolErrorInfo | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
tool_id: int | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
tool_id=tool_id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -657,13 +674,55 @@ def translate_assistant_message_to_packets(
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
# Try to parse as structured CustomToolCallSummary JSON
|
||||
custom_data: dict | list | str | int | float | bool | None = (
|
||||
tool_call.tool_call_response
|
||||
)
|
||||
custom_error: CustomToolErrorInfo | None = None
|
||||
custom_response_type = "text"
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_call.tool_call_response)
|
||||
if isinstance(parsed, dict) and "tool_name" in parsed:
|
||||
custom_data = parsed.get("tool_result")
|
||||
custom_response_type = parsed.get(
|
||||
"response_type", "text"
|
||||
)
|
||||
if parsed.get("error"):
|
||||
custom_error = CustomToolErrorInfo(
|
||||
**parsed["error"]
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
ValidationError,
|
||||
):
|
||||
pass
|
||||
|
||||
custom_file_ids: list[str] | None = None
|
||||
if custom_response_type in ("image", "csv") and isinstance(
|
||||
custom_data, dict
|
||||
):
|
||||
custom_file_ids = custom_data.get("file_ids")
|
||||
custom_data = None
|
||||
|
||||
custom_args = {
|
||||
k: v
|
||||
for k, v in (tool_call.tool_call_arguments or {}).items()
|
||||
if k != "requestBody"
|
||||
}
|
||||
turn_tool_packets.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool.display_name or tool.name,
|
||||
response_type="text",
|
||||
response_type=custom_response_type,
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
data=tool_call.tool_call_response,
|
||||
data=custom_data,
|
||||
file_ids=custom_file_ids,
|
||||
error=custom_error,
|
||||
tool_args=custom_args if custom_args else None,
|
||||
tool_id=tool_call.tool_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class StreamingType(Enum):
|
||||
PYTHON_TOOL_START = "python_tool_start"
|
||||
PYTHON_TOOL_DELTA = "python_tool_delta"
|
||||
CUSTOM_TOOL_START = "custom_tool_start"
|
||||
CUSTOM_TOOL_ARGS = "custom_tool_args"
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta"
|
||||
FILE_READER_START = "file_reader_start"
|
||||
FILE_READER_RESULT = "file_reader_result"
|
||||
@@ -245,6 +246,20 @@ class CustomToolStart(BaseObj):
|
||||
type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
|
||||
|
||||
class CustomToolArgs(BaseObj):
|
||||
type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value
|
||||
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolErrorInfo(BaseModel):
|
||||
is_auth_error: bool = False
|
||||
status_code: int
|
||||
message: str
|
||||
|
||||
|
||||
# The allowed streamed packets for a custom tool
|
||||
@@ -252,11 +267,13 @@ class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
################################################
|
||||
@@ -366,6 +383,7 @@ PacketObj = Union[
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolArgs,
|
||||
CustomToolDelta,
|
||||
FileReaderStart,
|
||||
FileReaderResult,
|
||||
|
||||
@@ -8,8 +8,6 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
@@ -165,39 +163,6 @@ def create_image_generation_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_custom_tool_packets(
|
||||
tool_name: str,
|
||||
response_type: str,
|
||||
turn_index: int,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetch_docs: list[SavedSearchDoc],
|
||||
urls: list[str],
|
||||
|
||||
@@ -275,9 +275,13 @@ def setup_postgres(db_session: Session) -> None:
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
try:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to upsert LLM provider during setup: %s", e)
|
||||
return
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
@@ -61,6 +62,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
|
||||
tool_result: Any # The response data
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallKickoff(BaseModel):
|
||||
|
||||
@@ -15,7 +15,9 @@ from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
@@ -139,7 +141,7 @@ class CustomTool(Tool[None]):
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolStart(tool_name=self._name),
|
||||
obj=CustomToolStart(tool_name=self._name, tool_id=self._id),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -149,10 +151,8 @@ class CustomTool(Tool[None]):
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
|
||||
# Build path params
|
||||
path_params = {}
|
||||
|
||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||
param_name = path_param_schema["name"]
|
||||
if param_name not in llm_kwargs:
|
||||
@@ -165,6 +165,7 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
path_params[param_name] = llm_kwargs[param_name]
|
||||
|
||||
# Build query params
|
||||
query_params = {}
|
||||
for query_param_schema in self._method_spec.get_query_param_schemas():
|
||||
if query_param_schema["name"] in llm_kwargs:
|
||||
@@ -172,6 +173,20 @@ class CustomTool(Tool[None]):
|
||||
query_param_schema["name"]
|
||||
]
|
||||
|
||||
# Emit args packet (path + query params only, no request body)
|
||||
tool_args = {**path_params, **query_params}
|
||||
if tool_args:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolArgs(
|
||||
tool_name=self._name,
|
||||
tool_args=tool_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
url = self._method_spec.build_url(self._base_url, path_params, query_params)
|
||||
method = self._method_spec.method
|
||||
|
||||
@@ -180,6 +195,18 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
# Detect HTTP errors — only 401/403 are flagged as auth errors
|
||||
error_info: CustomToolErrorInfo | None = None
|
||||
if response.status_code in (401, 403):
|
||||
error_info = CustomToolErrorInfo(
|
||||
is_auth_error=True,
|
||||
status_code=response.status_code,
|
||||
message=f"{self._name} action failed because of authentication error",
|
||||
)
|
||||
logger.warning(
|
||||
f"Auth error from custom tool '{self._name}': HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
tool_result: Any
|
||||
response_type: str
|
||||
file_ids: List[str] | None = None
|
||||
@@ -222,9 +249,11 @@ class CustomTool(Tool[None]):
|
||||
placement=placement,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
tool_id=self._id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error_info,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -236,6 +265,7 @@ class CustomTool(Tool[None]):
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
error=error_info,
|
||||
),
|
||||
llm_facing_response=llm_facing_response,
|
||||
)
|
||||
|
||||
@@ -111,19 +111,26 @@ def _normalize_text_with_mapping(text: str) -> tuple[str, list[int]]:
|
||||
# Step 1: NFC normalization with position mapping
|
||||
nfc_text = unicodedata.normalize("NFC", text)
|
||||
|
||||
# Build mapping from NFC positions to original start positions
|
||||
# Map NFD positions → original positions.
|
||||
# NFD only decomposes, so each original char produces 1+ NFD chars.
|
||||
nfd_to_orig: list[int] = []
|
||||
for orig_idx, orig_char in enumerate(original_text):
|
||||
nfd_of_char = unicodedata.normalize("NFD", orig_char)
|
||||
for _ in nfd_of_char:
|
||||
nfd_to_orig.append(orig_idx)
|
||||
|
||||
# Map NFC positions → NFD positions.
|
||||
# Each NFC char, when decomposed, tells us exactly how many NFD
|
||||
# chars it was composed from.
|
||||
nfc_to_orig: list[int] = []
|
||||
orig_idx = 0
|
||||
nfd_idx = 0
|
||||
for nfc_char in nfc_text:
|
||||
nfc_to_orig.append(orig_idx)
|
||||
# Find how many original chars contributed to this NFC char
|
||||
for length in range(1, len(original_text) - orig_idx + 1):
|
||||
substr = original_text[orig_idx : orig_idx + length]
|
||||
if unicodedata.normalize("NFC", substr) == nfc_char:
|
||||
orig_idx += length
|
||||
break
|
||||
if nfd_idx < len(nfd_to_orig):
|
||||
nfc_to_orig.append(nfd_to_orig[nfd_idx])
|
||||
else:
|
||||
orig_idx += 1 # Fallback
|
||||
nfc_to_orig.append(len(original_text) - 1)
|
||||
nfd_of_nfc = unicodedata.normalize("NFD", nfc_char)
|
||||
nfd_idx += len(nfd_of_nfc)
|
||||
|
||||
# Work with NFC text from here
|
||||
text = nfc_text
|
||||
|
||||
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
427
backend/onyx/utils/jsonriver/parse.py
Normal file
427
backend/onyx/utils/jsonriver/parse.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if (
|
||||
prev_val
|
||||
and len(cur_val) >= len(prev_val)
|
||||
and cur_val[len(prev_val) - 1] != prev_val[-1]
|
||||
):
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if (
|
||||
prev
|
||||
and len(current) >= len(prev)
|
||||
and current[len(prev) - 1] != prev[-1]
|
||||
):
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -1,30 +1,49 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
def sanitize_string(value: str) -> str:
|
||||
"""Strip characters that PostgreSQL text/JSONB columns cannot store.
|
||||
|
||||
Removes:
|
||||
- NUL bytes (\\x00)
|
||||
- UTF-16 surrogates (\\ud800-\\udfff), which are invalid in UTF-8
|
||||
"""
|
||||
sanitized = value.replace("\x00", "")
|
||||
sanitized = _SURROGATE_RE.sub("", sanitized)
|
||||
if value and not sanitized:
|
||||
logger.warning(
|
||||
"sanitize_string: all characters were removed from a non-empty string"
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_json_like(value: Any) -> Any:
|
||||
def sanitize_json_like(value: Any) -> Any:
|
||||
"""Recursively sanitize all strings in a JSON-like structure (dict/list/tuple)."""
|
||||
if isinstance(value, str):
|
||||
return _sanitize_string(value)
|
||||
return sanitize_string(value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_json_like(item) for item in value]
|
||||
return [sanitize_json_like(item) for item in value]
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_sanitize_json_like(item) for item in value)
|
||||
return tuple(sanitize_json_like(item) for item in value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[Any, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
|
||||
cleaned_key = sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = sanitize_json_like(nested_value)
|
||||
return sanitized
|
||||
|
||||
return value
|
||||
@@ -34,27 +53,27 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
return expert.model_copy(
|
||||
update={
|
||||
"display_name": (
|
||||
_sanitize_string(expert.display_name)
|
||||
sanitize_string(expert.display_name)
|
||||
if expert.display_name is not None
|
||||
else None
|
||||
),
|
||||
"first_name": (
|
||||
_sanitize_string(expert.first_name)
|
||||
sanitize_string(expert.first_name)
|
||||
if expert.first_name is not None
|
||||
else None
|
||||
),
|
||||
"middle_initial": (
|
||||
_sanitize_string(expert.middle_initial)
|
||||
sanitize_string(expert.middle_initial)
|
||||
if expert.middle_initial is not None
|
||||
else None
|
||||
),
|
||||
"last_name": (
|
||||
_sanitize_string(expert.last_name)
|
||||
sanitize_string(expert.last_name)
|
||||
if expert.last_name is not None
|
||||
else None
|
||||
),
|
||||
"email": (
|
||||
_sanitize_string(expert.email) if expert.email is not None else None
|
||||
sanitize_string(expert.email) if expert.email is not None else None
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -63,10 +82,10 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
|
||||
return ExternalAccess(
|
||||
external_user_emails={
|
||||
_sanitize_string(email) for email in external_access.external_user_emails
|
||||
sanitize_string(email) for email in external_access.external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
_sanitize_string(group_id)
|
||||
sanitize_string(group_id)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
},
|
||||
is_public=external_access.is_public,
|
||||
@@ -76,26 +95,26 @@ def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
cleaned_doc = document.model_copy(deep=True)
|
||||
|
||||
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
|
||||
cleaned_doc.id = sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = sanitize_string(cleaned_doc.semantic_identifier)
|
||||
if cleaned_doc.title is not None:
|
||||
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
|
||||
cleaned_doc.title = sanitize_string(cleaned_doc.title)
|
||||
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id
|
||||
)
|
||||
|
||||
cleaned_doc.metadata = {
|
||||
_sanitize_string(key): (
|
||||
[_sanitize_string(item) for item in value]
|
||||
sanitize_string(key): (
|
||||
[sanitize_string(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else _sanitize_string(value)
|
||||
else sanitize_string(value)
|
||||
)
|
||||
for key, value in cleaned_doc.metadata.items()
|
||||
}
|
||||
|
||||
if cleaned_doc.doc_metadata is not None:
|
||||
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
cleaned_doc.doc_metadata = sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
|
||||
if cleaned_doc.primary_owners is not None:
|
||||
cleaned_doc.primary_owners = [
|
||||
@@ -113,11 +132,11 @@ def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = _sanitize_string(section.link)
|
||||
section.link = sanitize_string(section.link)
|
||||
if section.text is not None:
|
||||
section.text = _sanitize_string(section.text)
|
||||
section.text = sanitize_string(section.text)
|
||||
if section.image_file_id is not None:
|
||||
section.image_file_id = _sanitize_string(section.image_file_id)
|
||||
section.image_file_id = sanitize_string(section.image_file_id)
|
||||
|
||||
return cleaned_doc
|
||||
|
||||
@@ -129,12 +148,12 @@ def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
|
||||
cleaned_node = node.model_copy(deep=True)
|
||||
|
||||
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
|
||||
cleaned_node.raw_node_id = sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = sanitize_string(cleaned_node.display_name)
|
||||
if cleaned_node.raw_parent_id is not None:
|
||||
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
|
||||
cleaned_node.raw_parent_id = sanitize_string(cleaned_node.raw_parent_id)
|
||||
if cleaned_node.link is not None:
|
||||
cleaned_node.link = _sanitize_string(cleaned_node.link)
|
||||
cleaned_node.link = sanitize_string(cleaned_node.link)
|
||||
|
||||
if cleaned_node.external_access is not None:
|
||||
cleaned_node.external_access = _sanitize_external_access(
|
||||
@@ -24,6 +24,9 @@ class OnyxVersion:
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def unset_ee(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
# zeep
|
||||
authlib==1.6.6
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
@@ -109,9 +109,7 @@ brotli==1.2.0
|
||||
bytecode==0.17.0
|
||||
# via ddtrace
|
||||
cachetools==6.2.2
|
||||
# via
|
||||
# google-auth
|
||||
# py-key-value-aio
|
||||
# via py-key-value-aio
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
@@ -190,6 +188,7 @@ courlan==1.3.2
|
||||
cryptography==46.0.5
|
||||
# via
|
||||
# authlib
|
||||
# google-auth
|
||||
# msal
|
||||
# msoffcrypto-tool
|
||||
# pdfminer-six
|
||||
@@ -230,9 +229,7 @@ distro==1.9.0
|
||||
dnspython==2.8.0
|
||||
# via email-validator
|
||||
docstring-parser==0.17.0
|
||||
# via
|
||||
# cyclopts
|
||||
# google-cloud-aiplatform
|
||||
# via cyclopts
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
@@ -297,26 +294,15 @@ gitdb==4.0.12
|
||||
gitpython==3.1.45
|
||||
# via braintrust
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# via google-api-python-client
|
||||
google-api-python-client==2.86.0
|
||||
# via onyx
|
||||
google-auth==2.43.0
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
# google-auth-oauthlib
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
@@ -325,51 +311,16 @@ google-auth-httplib2==0.1.0
|
||||
# onyx
|
||||
google-auth-oauthlib==1.0.0
|
||||
# via onyx
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# via onyx
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
greenlet==3.2.4
|
||||
# via
|
||||
# playwright
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -670,8 +621,6 @@ packaging==24.2
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# kombu
|
||||
@@ -721,19 +670,12 @@ propcache==0.4.1
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# via google-api-core
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
@@ -771,7 +713,6 @@ pydantic==2.11.7
|
||||
# exa-py
|
||||
# fastapi
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# langchain-core
|
||||
# langfuse
|
||||
@@ -835,7 +776,6 @@ python-dateutil==2.8.2
|
||||
# botocore
|
||||
# celery
|
||||
# dateparser
|
||||
# google-cloud-bigquery
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
@@ -927,8 +867,6 @@ requests==2.32.5
|
||||
# dropbox
|
||||
# exa-py
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# huggingface-hub
|
||||
@@ -1002,9 +940,7 @@ sendgrid==6.12.5
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
simple-salesforce==1.12.6
|
||||
@@ -1118,9 +1054,7 @@ typing-extensions==4.15.0
|
||||
# exa-py
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
|
||||
@@ -59,8 +59,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
celery-types==0.19.0
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
@@ -100,7 +98,9 @@ comm==0.2.3
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
debugpy==1.8.17
|
||||
@@ -115,8 +115,6 @@ distlib==0.4.0
|
||||
# via virtualenv
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
execnet==2.1.2
|
||||
@@ -145,65 +143,14 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# via onyx
|
||||
greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -311,13 +258,12 @@ numpy==2.4.1
|
||||
# contourpy
|
||||
# matplotlib
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.2
|
||||
onyx-devtools==0.6.3
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -330,8 +276,6 @@ openapi-generator-cli==7.17.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# hatchling
|
||||
# huggingface-hub
|
||||
# ipykernel
|
||||
@@ -374,20 +318,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via ipykernel
|
||||
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
@@ -409,7 +339,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -450,7 +379,6 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# jupyter-client
|
||||
# kubernetes
|
||||
# matplotlib
|
||||
@@ -478,16 +406,13 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.4.3
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -510,8 +435,6 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -602,9 +525,7 @@ typing-extensions==4.15.0
|
||||
# celery-types
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mcp
|
||||
|
||||
@@ -53,8 +53,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# httpcore
|
||||
@@ -79,15 +77,15 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.133.1
|
||||
@@ -104,63 +102,12 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
# via onyx
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -221,9 +168,7 @@ multidict==6.7.0
|
||||
# aiohttp
|
||||
# yarl
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# shapely
|
||||
# voyageai
|
||||
# via voyageai
|
||||
oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -233,10 +178,7 @@ openai==2.14.0
|
||||
# litellm
|
||||
# onyx
|
||||
packaging==24.2
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# via huggingface-hub
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
@@ -251,20 +193,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
@@ -280,7 +208,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -297,7 +224,6 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
# posthog
|
||||
python-dotenv==1.1.1
|
||||
@@ -321,9 +247,6 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -345,8 +268,6 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -385,9 +306,7 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
@@ -57,8 +57,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
celery==5.5.1
|
||||
# via sentry-sdk
|
||||
certifi==2025.11.12
|
||||
@@ -95,15 +93,15 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
@@ -129,63 +127,12 @@ fsspec==2025.10.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
# via onyx
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -263,7 +210,6 @@ numpy==2.4.1
|
||||
# onyx
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# shapely
|
||||
# transformers
|
||||
# voyageai
|
||||
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
@@ -316,8 +262,6 @@ openai==2.14.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# kombu
|
||||
# transformers
|
||||
@@ -337,20 +281,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
@@ -368,7 +298,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -386,7 +315,6 @@ python-dateutil==2.8.2
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# celery
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
@@ -413,9 +341,6 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -452,8 +377,6 @@ sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
setuptools==80.9.0 ; python_full_version >= '3.12'
|
||||
# via torch
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -510,9 +433,7 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
171
backend/scripts/debugging/opensearch/opensearch_debug.py
Normal file
171
backend/scripts/debugging/opensearch/opensearch_debug.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""A utility to interact with OpenSearch.
|
||||
|
||||
Usage:
|
||||
python3 opensearch_debug.py --help
|
||||
python3 opensearch_debug.py list
|
||||
python3 opensearch_debug.py delete <index_name>
|
||||
|
||||
Environment Variables:
|
||||
OPENSEARCH_HOST: OpenSearch host
|
||||
OPENSEARCH_REST_API_PORT: OpenSearch port
|
||||
OPENSEARCH_ADMIN_USERNAME: Admin username
|
||||
OPENSEARCH_ADMIN_PASSWORD: Admin password
|
||||
|
||||
Dependencies:
|
||||
backend/shared_configs/configs.py
|
||||
backend/onyx/document_index/opensearch/client.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def list_indices(client: OpenSearchClient) -> None:
|
||||
indices = client.list_indices_with_info()
|
||||
print(f"Found {len(indices)} indices.")
|
||||
print("-" * 80)
|
||||
for index in sorted(indices, key=lambda x: x.name):
|
||||
print(f"Index: {index.name}")
|
||||
print(f"Health: {index.health}")
|
||||
print(f"Status: {index.status}")
|
||||
print(f"Num Primary Shards: {index.num_primary_shards}")
|
||||
print(f"Num Replica Shards: {index.num_replica_shards}")
|
||||
print(f"Docs Count: {index.docs_count}")
|
||||
print(f"Docs Deleted: {index.docs_deleted}")
|
||||
print(f"Created At: {index.created_at}")
|
||||
print(f"Total Size: {index.total_size}")
|
||||
print(f"Primary Shards Size: {index.primary_shards_size}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def delete_index(client: OpenSearchIndexClient) -> None:
|
||||
if not client.index_exists():
|
||||
print(f"Index '{client._index_name}' does not exist.")
|
||||
return
|
||||
|
||||
confirm = input(f"Delete index '{client._index_name}'? (yes/no): ")
|
||||
if confirm.lower() != "yes":
|
||||
print("Aborted.")
|
||||
return
|
||||
|
||||
if client.delete_index():
|
||||
print(f"Deleted index '{client._index_name}'.")
|
||||
else:
|
||||
print(f"Failed to delete index '{client._index_name}' for an unknown reason.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def add_standard_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
help="OpenSearch host. If not provided, will fall back to OPENSEARCH_HOST, then prompt "
|
||||
"for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_HOST", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
help="OpenSearch port. If not provided, will fall back to OPENSEARCH_REST_API_PORT, "
|
||||
"then prompt for input.",
|
||||
type=int,
|
||||
default=int(os.environ.get("OPENSEARCH_REST_API_PORT", 0)),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--username",
|
||||
help="OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, "
|
||||
"then prompt for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_USERNAME", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password",
|
||||
help="OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, "
|
||||
"then prompt for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_PASSWORD", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-ssl", help="Disable SSL.", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify-certs",
|
||||
help="Disable certificate verification (for self-signed certs).",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-aws-managed-opensearch",
|
||||
help="Whether to use AWS-managed OpenSearch. If not provided, will fall back to checking "
|
||||
"USING_AWS_MANAGED_OPENSEARCH=='true', then default to False.",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower()
|
||||
== "true",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A utility to interact with OpenSearch."
|
||||
)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest="command", help="Command to execute.", required=True
|
||||
)
|
||||
|
||||
list_parser = subparsers.add_parser("list", help="List all indices with info.")
|
||||
add_standard_arguments(list_parser)
|
||||
|
||||
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
|
||||
delete_parser.add_argument("index", help="Index name.", type=str)
|
||||
add_standard_arguments(delete_parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not (host := args.host or input("Enter the OpenSearch host: ")):
|
||||
print("Error: OpenSearch host is required.")
|
||||
sys.exit(1)
|
||||
if not (port := args.port or int(input("Enter the OpenSearch port: "))):
|
||||
print("Error: OpenSearch port is required.")
|
||||
sys.exit(1)
|
||||
if not (username := args.username or input("Enter the OpenSearch username: ")):
|
||||
print("Error: OpenSearch username is required.")
|
||||
sys.exit(1)
|
||||
if not (password := args.password or input("Enter the OpenSearch password: ")):
|
||||
print("Error: OpenSearch password is required.")
|
||||
sys.exit(1)
|
||||
print("Using AWS-managed OpenSearch: ", args.use_aws_managed_opensearch)
|
||||
print(f"MULTI_TENANT: {MULTI_TENANT}")
|
||||
|
||||
with (
|
||||
OpenSearchIndexClient(
|
||||
index_name=args.index,
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
use_ssl=not args.no_ssl,
|
||||
verify_certs=not args.no_verify_certs,
|
||||
)
|
||||
if args.command == "delete"
|
||||
else OpenSearchClient(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
use_ssl=not args.no_ssl,
|
||||
verify_certs=not args.no_verify_certs,
|
||||
)
|
||||
) as client:
|
||||
if not client.ping():
|
||||
print("Error: Could not connect to OpenSearch.")
|
||||
sys.exit(1)
|
||||
|
||||
if args.command == "list":
|
||||
list_indices(client)
|
||||
elif args.command == "delete":
|
||||
delete_index(client)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
71
backend/tests/README.md
Normal file
71
backend/tests/README.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Backend Tests
|
||||
|
||||
## Test Types
|
||||
|
||||
There are four test categories, ordered by increasing scope:
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
|
||||
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
|
||||
logic (e.g. citation processing, encryption).
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
|
||||
|
||||
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
|
||||
application containers are not. Tests call functions directly and can mock selectively.
|
||||
|
||||
Use when you need a real database or real API calls but want control over setup.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
|
||||
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright / E2E Tests (`web/tests/e2e/`)
|
||||
|
||||
Full stack including web server. Use for frontend-backend coordination.
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
## Shared Fixtures
|
||||
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Enables EE mode for a test, with proper teardown and cache clearing.
|
||||
|
||||
```python
|
||||
# Whole file (in a test module, NOT in conftest.py)
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Whole directory — add an autouse wrapper to the directory's conftest.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
|
||||
# Single test
|
||||
def test_something(enable_ee: None) -> None: ...
|
||||
```
|
||||
|
||||
**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that
|
||||
directory — it only affects tests defined in the conftest itself (which is none).
|
||||
Use the autouse fixture wrapper pattern shown above instead.
|
||||
|
||||
Do NOT inline `global_version.set_ee()` — always use the fixture.
|
||||
24
backend/tests/conftest.py
Normal file
24
backend/tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Root conftest — shared fixtures available to all test directories."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def enable_ee() -> Generator[None, None, None]:
|
||||
"""Temporarily enable EE mode for a single test.
|
||||
|
||||
Restores the previous EE state and clears the versioned-implementation
|
||||
cache on teardown so state doesn't leak between tests.
|
||||
"""
|
||||
was_ee = global_version.is_ee_version()
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
if not was_ee:
|
||||
global_version.unset_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
def test_confluence_connector_permissions(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
confluence_connector: ConfluenceConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
|
||||
def test_confluence_connector_restriction_handling(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
mock_db_provider_class: MagicMock,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Test space key
|
||||
test_space_key = "DailyPermS"
|
||||
|
||||
@@ -4,8 +4,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
@@ -14,14 +12,3 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
@@ -98,7 +98,7 @@ def _build_connector(
|
||||
|
||||
def test_gdrive_perm_sync_with_real_data(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
|
||||
@@ -19,16 +17,7 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.teams.models import TeamsThread
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
@@ -168,18 +166,9 @@ def test_slim_docs_retrieval_from_teams_connector(
|
||||
_assert_is_valid_external_access(external_access=slim_doc.external_access)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for perm sync tests to work since
|
||||
perm syncing is an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
yield
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_perm_sync(
|
||||
teams_connector: TeamsConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.
|
||||
|
||||
|
||||
@@ -145,6 +145,10 @@ class TestDocprocessingPriorityInDocumentExtraction:
|
||||
@patch("onyx.background.indexing.run_docfetching.get_document_batch_storage")
|
||||
@patch("onyx.background.indexing.run_docfetching.MemoryTracer")
|
||||
@patch("onyx.background.indexing.run_docfetching._get_connector_runner")
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.strip_null_characters",
|
||||
side_effect=lambda batch: batch,
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.get_recent_completed_attempts_for_cc_pair"
|
||||
)
|
||||
@@ -169,6 +173,7 @@ class TestDocprocessingPriorityInDocumentExtraction:
|
||||
mock_save_checkpoint: MagicMock, # noqa: ARG002
|
||||
mock_get_last_successful_attempt_poll_range_end: MagicMock,
|
||||
mock_get_recent_completed_attempts: MagicMock,
|
||||
mock_strip_null_characters: MagicMock, # noqa: ARG002
|
||||
mock_get_connector_runner: MagicMock,
|
||||
mock_memory_tracer_class: MagicMock,
|
||||
mock_get_batch_storage: MagicMock,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -14,13 +15,14 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
# In order to get these tests to run, use the credentials from Bitwarden.
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
class DocExternalAccessSet(BaseModel):
|
||||
"""A version of DocExternalAccess that uses sets for comparison."""
|
||||
@@ -52,9 +54,6 @@ def test_jira_doc_sync(
|
||||
This test uses the AS project which has applicationRole permission,
|
||||
meaning all documents should be marked as public.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use AS project specifically for this test
|
||||
connector_config = {
|
||||
@@ -150,9 +149,6 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
This test uses a project that has specific user permissions to verify
|
||||
that specific users are correctly extracted.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use SUP project which has specific user permissions
|
||||
connector_config = {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
@@ -18,6 +19,8 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Expected groups from the danswerai.atlassian.net Jira instance
|
||||
# Note: These groups are shared with Confluence since they're both Atlassian products
|
||||
# App accounts (bots, integrations) are filtered out
|
||||
|
||||
@@ -158,7 +158,7 @@ class TestLLMConfigurationEndpoint:
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -540,7 +540,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "No LLM Provider setup" in exc_info.value.message
|
||||
assert "No LLM Provider setup" in exc_info.value.detail
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -585,7 +585,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -247,7 +247,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -350,7 +350,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -386,7 +386,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -698,6 +698,99 @@ class TestAutoModeMissingFlows:
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_transition_to_auto_mode_preserves_default(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the default provider transitions from manual to auto mode,
|
||||
the global default should be preserved (set to the recommended model).
|
||||
|
||||
Steps:
|
||||
1. Create a manual-mode provider with models, set it as global default.
|
||||
2. Transition to auto mode (model_configurations=[] triggers cascade
|
||||
delete of old ModelConfigurations and their LLMModelFlow rows).
|
||||
3. Verify the provider is still the global default, now using the
|
||||
recommended default model from the GitHub config.
|
||||
"""
|
||||
initial_models = [
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True),
|
||||
]
|
||||
|
||||
auto_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create manual-mode provider and set as default
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False,
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
assert default_before.llm_provider_id == provider.id
|
||||
|
||||
# Step 2: Transition to auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=auto_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Default should be preserved on this provider
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None, (
|
||||
"Default model should not be None after transitioning to auto mode — "
|
||||
"the provider was the default before and should remain so"
|
||||
)
|
||||
assert (
|
||||
default_after.llm_provider_id == provider.id
|
||||
), "Default should still belong to the same provider after transition"
|
||||
assert default_after.name == "gpt-4o-mini", (
|
||||
f"Default should be updated to the recommended model 'gpt-4o-mini', "
|
||||
f"got '{default_after.name}'"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
@@ -1042,14 +1135,195 @@ class TestAutoModeTransitionsAndResync:
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
# The old default (gpt-4o) is now hidden. sync_auto_mode_models
|
||||
# should update the global default to the new recommended default
|
||||
# (gpt-4o-mini) so that it is not silently lost.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
assert default_after is not None, (
|
||||
"Default model should not be None — sync should set the new "
|
||||
"recommended default when the old one is hidden"
|
||||
)
|
||||
assert default_after.name == "gpt-4o-mini", (
|
||||
f"Default should be updated to the new recommended model "
|
||||
f"'gpt-4o-mini', but got '{default_after.name}'"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_updates_default_when_recommended_default_changes(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and a sync arrives with a
|
||||
different recommended default model (both models still in config),
|
||||
the global default should be updated to the new recommendation.
|
||||
|
||||
Steps:
|
||||
1. Create auto-mode provider with config v1: default=gpt-4o.
|
||||
2. Set gpt-4o as the global CHAT default.
|
||||
3. Re-sync with config v2: default=gpt-4o-mini (gpt-4o still present).
|
||||
4. Verify the CHAT default switched to gpt-4o-mini and both models
|
||||
remain visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o as the global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Re-sync with config v2 (recommended default changed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0, "Sync should report changes when default switches"
|
||||
|
||||
# Both models should remain visible
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is True
|
||||
assert visibility["gpt-4o-mini"] is True
|
||||
|
||||
# The CHAT default should now be gpt-4o-mini
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None
|
||||
assert (
|
||||
default_after.name == "gpt-4o-mini"
|
||||
), f"Default should be updated to 'gpt-4o-mini', got '{default_after.name}'"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_idempotent_when_default_already_matches(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and it already matches the
|
||||
recommended default, re-syncing should report zero changes.
|
||||
|
||||
This is a regression test for the bug where changes was unconditionally
|
||||
incremented even when the default was already correct.
|
||||
"""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o (the recommended default) as global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# First sync to stabilize state
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Second sync — default already matches, should be a no-op
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert changes == 0, (
|
||||
f"Expected 0 changes when default already matches recommended, "
|
||||
f"got {changes}"
|
||||
)
|
||||
|
||||
# Default should still be gpt-4o
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == "gpt-4o"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
This should act as the main point of reference for testing that default model
|
||||
logic is consisten.
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
|
||||
|
||||
def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
models: list[ModelConfigurationUpsertRequest] | None = None,
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider with multiple models."""
|
||||
if models is None:
|
||||
models = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True, supports_image_input=False
|
||||
),
|
||||
]
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=models,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_provider(db_session: Session, name: str) -> None:
|
||||
"""Helper to clean up a test provider by name."""
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if provider:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_name(db_session: Session) -> Generator[str, None, None]:
|
||||
"""Generate a unique provider name for each test, with automatic cleanup."""
|
||||
name = f"test-provider-{uuid4().hex[:8]}"
|
||||
yield name
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, name)
|
||||
|
||||
|
||||
class TestDefaultModelProtection:
|
||||
"""Tests that the default model cannot be removed or hidden."""
|
||||
|
||||
def test_cannot_remove_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default text model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to update the provider without the default model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_hide_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Setting is_visible=False on the default text model should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to hide the default model
|
||||
with pytest.raises(ValueError, match="Cannot hide the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=False
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_remove_default_vision_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default vision model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
# Set gpt-4o as both the text and vision default
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
update_default_vision_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to remove the default vision model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_can_remove_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Remove gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_names = {mc.name for mc in updated.model_configurations}
|
||||
assert "gpt-4o" in model_names
|
||||
assert "gpt-4o-mini" not in model_names
|
||||
|
||||
def test_can_hide_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Hiding a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Hide gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=False
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in updated.model_configurations
|
||||
}
|
||||
assert model_visibility["gpt-4o"] is True
|
||||
assert model_visibility["gpt-4o-mini"] is False
|
||||
@@ -21,6 +21,8 @@ from onyx.db.oauth_config import get_tools_by_oauth_config
|
||||
from onyx.db.oauth_config import get_user_oauth_token
|
||||
from onyx.db.oauth_config import update_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.db.tools import delete_tool__no_commit
|
||||
from onyx.db.tools import update_tool
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -312,6 +314,85 @@ class TestOAuthConfigCRUD:
|
||||
# Tool should still exist but oauth_config_id should be NULL
|
||||
assert tool.oauth_config_id is None
|
||||
|
||||
def test_update_tool_cleans_up_orphaned_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that changing a tool's oauth_config_id deletes the old config if no other tool uses it."""
|
||||
old_config = _create_test_oauth_config(db_session)
|
||||
new_config = _create_test_oauth_config(db_session)
|
||||
tool = _create_test_tool_with_oauth(db_session, old_config)
|
||||
old_config_id = old_config.id
|
||||
|
||||
update_tool(
|
||||
tool_id=tool.id,
|
||||
name=None,
|
||||
description=None,
|
||||
openapi_schema=None,
|
||||
custom_headers=None,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=None,
|
||||
oauth_config_id=new_config.id,
|
||||
)
|
||||
|
||||
assert tool.oauth_config_id == new_config.id
|
||||
assert get_oauth_config(old_config_id, db_session) is None
|
||||
|
||||
def test_delete_tool_cleans_up_orphaned_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that deleting the last tool referencing an OAuthConfig also deletes the config."""
|
||||
config = _create_test_oauth_config(db_session)
|
||||
tool = _create_test_tool_with_oauth(db_session, config)
|
||||
config_id = config.id
|
||||
|
||||
delete_tool__no_commit(tool.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert get_oauth_config(config_id, db_session) is None
|
||||
|
||||
def test_update_tool_preserves_shared_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that updating one tool's oauth_config_id preserves the config when another tool still uses it."""
|
||||
shared_config = _create_test_oauth_config(db_session)
|
||||
new_config = _create_test_oauth_config(db_session)
|
||||
tool_a = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
tool_b = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
shared_config_id = shared_config.id
|
||||
|
||||
# Move tool_a to a new config; tool_b still references shared_config
|
||||
update_tool(
|
||||
tool_id=tool_a.id,
|
||||
name=None,
|
||||
description=None,
|
||||
openapi_schema=None,
|
||||
custom_headers=None,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=None,
|
||||
oauth_config_id=new_config.id,
|
||||
)
|
||||
|
||||
assert tool_a.oauth_config_id == new_config.id
|
||||
assert tool_b.oauth_config_id == shared_config_id
|
||||
assert get_oauth_config(shared_config_id, db_session) is not None
|
||||
|
||||
def test_delete_tool_preserves_shared_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that deleting one tool preserves the config when another tool still uses it."""
|
||||
shared_config = _create_test_oauth_config(db_session)
|
||||
tool_a = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
tool_b = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
shared_config_id = shared_config.id
|
||||
|
||||
delete_tool__no_commit(tool_a.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert tool_b.oauth_config_id == shared_config_id
|
||||
assert get_oauth_config(shared_config_id, db_session) is not None
|
||||
|
||||
|
||||
class TestOAuthUserTokenCRUD:
|
||||
"""Tests for OAuth user token CRUD operations"""
|
||||
|
||||
@@ -427,7 +427,7 @@ def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
@@ -674,7 +674,7 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=base_payload,
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
@@ -711,7 +711,7 @@ def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "not currently supported" in response.json()["message"]
|
||||
assert "not currently supported" in response.json()["detail"]
|
||||
|
||||
# Verify no duplicate was created — only the original provider should exist
|
||||
provider = _get_provider_by_id(admin_user, provider_id)
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(
|
||||
|
||||
# Should return 403 Forbidden
|
||||
assert response.status_code == 403
|
||||
assert "don't have access to this assistant" in response.json()["message"]
|
||||
assert "don't have access to this assistant" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_authorized_persona_access_returns_filtered_providers(
|
||||
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(
|
||||
|
||||
# Should return 404
|
||||
assert response.status_code == 404
|
||||
assert "Persona not found" in response.json()["message"]
|
||||
assert "Persona not found" in response.json()["detail"]
|
||||
|
||||
8
backend/tests/unit/ee/conftest.py
Normal file
8
backend/tests/unit/ee/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Auto-enable EE mode for all tests under tests/unit/ee/."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
@@ -107,7 +107,7 @@ class TestCreateCheckoutSession:
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Stripe error"
|
||||
assert exc_info.value.detail == "Stripe error"
|
||||
|
||||
|
||||
class TestCreateCustomerPortalSession:
|
||||
@@ -137,7 +137,7 @@ class TestCreateCustomerPortalSession:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert exc_info.value.detail == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.create_portal_service")
|
||||
@@ -243,7 +243,7 @@ class TestUpdateSeats:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert exc_info.value.detail == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.get_used_seats")
|
||||
@@ -317,7 +317,7 @@ class TestUpdateSeats:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Cannot reduce below 10 seats"
|
||||
assert exc_info.value.detail == "Cannot reduce below 10 seats"
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@@ -346,7 +346,7 @@ class TestCircuitBreaker:
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.error_code is OnyxErrorCode.SERVICE_UNAVAILABLE
|
||||
assert "Connect to Stripe" in exc_info.value.message
|
||||
assert "Connect to Stripe" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestMakeBillingRequest:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Bad request" in exc_info.value.message
|
||||
assert "Bad request" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.service._get_headers")
|
||||
@@ -152,7 +152,7 @@ class TestMakeBillingRequest:
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Failed to connect" in exc_info.value.message
|
||||
assert "Failed to connect" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateCheckoutSession:
|
||||
|
||||
@@ -72,7 +72,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert exc_info.value.detail == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -97,7 +97,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert exc_info.value.detail == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -118,7 +118,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Failed to fetch Stripe publishable key"
|
||||
assert exc_info.value.detail == "Failed to fetch Stripe publishable key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -132,7 +132,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert "not configured" in exc_info.value.message
|
||||
assert "not configured" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
|
||||
from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
@@ -14,22 +13,11 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
class _StubConfig:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.model_provider = model_provider
|
||||
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.config = _StubConfig(model_provider=model_provider)
|
||||
|
||||
|
||||
def create_message(
|
||||
content: str, message_type: MessageType, token_count: int | None = None
|
||||
) -> ChatMessageSimple:
|
||||
@@ -946,37 +934,6 @@ class TestForgottenFileMetadata:
|
||||
assert "moby_dick.txt" in forgotten.message
|
||||
|
||||
|
||||
class TestBedrockToolConfigGuard:
|
||||
def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is True
|
||||
|
||||
def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_message("Answer", MessageType.ASSISTANT, 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.OPENAI)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
|
||||
class TestFallbackToolExtraction:
|
||||
def _tool_defs(self) -> list[dict]:
|
||||
return [
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.chat.llm_step import _extract_tool_call_kickoffs
|
||||
from onyx.chat.llm_step import _increment_turns
|
||||
from onyx.chat.llm_step import _parse_tool_args_to_dict
|
||||
from onyx.chat.llm_step import _resolve_tool_arguments
|
||||
from onyx.chat.llm_step import _sanitize_llm_output
|
||||
from onyx.chat.llm_step import _XmlToolCallContentFilter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import translate_history_to_llm_format
|
||||
@@ -21,48 +20,49 @@ from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
class TestSanitizeLlmOutput:
|
||||
"""Tests for the _sanitize_llm_output function."""
|
||||
"""Tests for the sanitize_string function."""
|
||||
|
||||
def test_removes_null_bytes(self) -> None:
|
||||
"""Test that NULL bytes are removed from strings."""
|
||||
assert _sanitize_llm_output("hello\x00world") == "helloworld"
|
||||
assert _sanitize_llm_output("\x00start") == "start"
|
||||
assert _sanitize_llm_output("end\x00") == "end"
|
||||
assert _sanitize_llm_output("\x00\x00\x00") == ""
|
||||
assert sanitize_string("hello\x00world") == "helloworld"
|
||||
assert sanitize_string("\x00start") == "start"
|
||||
assert sanitize_string("end\x00") == "end"
|
||||
assert sanitize_string("\x00\x00\x00") == ""
|
||||
|
||||
def test_removes_surrogates(self) -> None:
|
||||
"""Test that UTF-16 surrogates are removed from strings."""
|
||||
# Low surrogate
|
||||
assert _sanitize_llm_output("hello\ud800world") == "helloworld"
|
||||
assert sanitize_string("hello\ud800world") == "helloworld"
|
||||
# High surrogate
|
||||
assert _sanitize_llm_output("hello\udfffworld") == "helloworld"
|
||||
assert sanitize_string("hello\udfffworld") == "helloworld"
|
||||
# Middle of surrogate range
|
||||
assert _sanitize_llm_output("test\uda00value") == "testvalue"
|
||||
assert sanitize_string("test\uda00value") == "testvalue"
|
||||
|
||||
def test_removes_mixed_bad_characters(self) -> None:
|
||||
"""Test removal of both NULL bytes and surrogates together."""
|
||||
assert _sanitize_llm_output("a\x00b\ud800c\udfffd") == "abcd"
|
||||
assert sanitize_string("a\x00b\ud800c\udfffd") == "abcd"
|
||||
|
||||
def test_preserves_valid_unicode(self) -> None:
|
||||
"""Test that valid Unicode characters are preserved."""
|
||||
# Emojis
|
||||
assert _sanitize_llm_output("hello 👋 world") == "hello 👋 world"
|
||||
assert sanitize_string("hello 👋 world") == "hello 👋 world"
|
||||
# Chinese characters
|
||||
assert _sanitize_llm_output("你好世界") == "你好世界"
|
||||
assert sanitize_string("你好世界") == "你好世界"
|
||||
# Mixed scripts
|
||||
assert _sanitize_llm_output("Hello мир 世界") == "Hello мир 世界"
|
||||
assert sanitize_string("Hello мир 世界") == "Hello мир 世界"
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
"""Test that empty strings are handled correctly."""
|
||||
assert _sanitize_llm_output("") == ""
|
||||
assert sanitize_string("") == ""
|
||||
|
||||
def test_normal_ascii(self) -> None:
|
||||
"""Test that normal ASCII strings pass through unchanged."""
|
||||
assert _sanitize_llm_output("hello world") == "hello world"
|
||||
assert _sanitize_llm_output('{"key": "value"}') == '{"key": "value"}'
|
||||
assert sanitize_string("hello world") == "hello world"
|
||||
assert sanitize_string('{"key": "value"}') == '{"key": "value"}'
|
||||
|
||||
|
||||
class TestParseToolArgsToDict:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Tests for _extract_referenced_file_descriptors in save_chat.py.
|
||||
"""Tests for save_chat.py.
|
||||
|
||||
Verifies that only code interpreter generated files actually referenced
|
||||
in the assistant's message text are extracted as FileDescriptors for
|
||||
cross-turn persistence.
|
||||
Covers _extract_referenced_file_descriptors and sanitization in save_chat_turn.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from onyx.chat import save_chat
|
||||
from onyx.chat.save_chat import _extract_referenced_file_descriptors
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
@@ -29,6 +32,9 @@ def _make_tool_call_info(
|
||||
)
|
||||
|
||||
|
||||
# ---- _extract_referenced_file_descriptors tests ----
|
||||
|
||||
|
||||
def test_returns_empty_when_no_generated_files() -> None:
|
||||
tool_call = _make_tool_call_info(generated_files=None)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "some message")
|
||||
@@ -176,3 +182,34 @@ def test_skips_tool_calls_without_generated_files() -> None:
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
|
||||
|
||||
# ---- save_chat_turn sanitization test ----
|
||||
|
||||
|
||||
def test_save_chat_turn_sanitizes_message_and_reasoning(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||
monkeypatch.setattr(save_chat, "get_tokenizer", lambda *_a, **_kw: mock_tokenizer)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 1
|
||||
mock_msg.chat_session_id = "test"
|
||||
mock_msg.files = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
save_chat.save_chat_turn(
|
||||
message_text="hello\x00world\ud800",
|
||||
reasoning_tokens="think\x00ing\udfff",
|
||||
tool_calls=[],
|
||||
citation_to_doc={},
|
||||
all_search_docs={},
|
||||
db_session=mock_session,
|
||||
assistant_message=mock_msg,
|
||||
)
|
||||
|
||||
assert mock_msg.message == "helloworld"
|
||||
assert mock_msg.reasoning_tokens == "thinking"
|
||||
@@ -9,6 +9,8 @@ from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jira_cc_pair(
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
models = [
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
27
backend/tests/unit/onyx/db/test_tools.py
Normal file
27
backend/tests/unit/onyx/db/test_tools.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db import tools as tools_mod
|
||||
|
||||
|
||||
def test_create_tool_call_no_commit_sanitizes_fields() -> None:
|
||||
mock_session = MagicMock()
|
||||
|
||||
tool_call = tools_mod.create_tool_call_no_commit(
|
||||
chat_session_id=uuid4(),
|
||||
parent_chat_message_id=1,
|
||||
turn_number=0,
|
||||
tool_id=1,
|
||||
tool_call_id="tc-1",
|
||||
tool_call_arguments={"task\x00": "research\ud800 topic"},
|
||||
tool_call_response="report\x00 text\udfff here",
|
||||
tool_call_tokens=10,
|
||||
db_session=mock_session,
|
||||
reasoning_tokens="reason\x00ing\ud800",
|
||||
generated_images=[{"url": "img\x00.png\udfff"}],
|
||||
)
|
||||
|
||||
assert tool_call.tool_call_response == "report text here"
|
||||
assert tool_call.reasoning_tokens == "reasoning"
|
||||
assert tool_call.tool_call_arguments == {"task": "research topic"}
|
||||
assert tool_call.generated_images == [{"url": "img.png"}]
|
||||
@@ -15,12 +15,12 @@ class TestOnyxError:
|
||||
def test_basic_construction(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
assert err.error_code is OnyxErrorCode.NOT_FOUND
|
||||
assert err.message == "Session not found"
|
||||
assert err.detail == "Session not found"
|
||||
assert err.status_code == 404
|
||||
|
||||
def test_message_defaults_to_code(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
assert err.message == "UNAUTHENTICATED"
|
||||
assert err.detail == "UNAUTHENTICATED"
|
||||
assert str(err) == "UNAUTHENTICATED"
|
||||
|
||||
def test_status_code_override(self) -> None:
|
||||
@@ -73,18 +73,18 @@ class TestExceptionHandler:
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "NOT_FOUND"
|
||||
assert body["message"] == "Thing not found"
|
||||
assert body["detail"] == "Thing not found"
|
||||
|
||||
def test_status_code_override_in_response(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-override")
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "BAD_GATEWAY"
|
||||
assert body["message"] == "upstream 503"
|
||||
assert body["detail"] == "upstream 503"
|
||||
|
||||
def test_default_message(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-default-msg")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "UNAUTHENTICATED"
|
||||
assert body["message"] == "UNAUTHENTICATED"
|
||||
assert body["detail"] == "UNAUTHENTICATED"
|
||||
|
||||
@@ -1214,3 +1214,218 @@ def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
|
||||
|
||||
# The env lock context manager should never have been called
|
||||
mock_env_lock.assert_not_called()
|
||||
|
||||
|
||||
# ---- Tests for Bedrock tool content stripping ----
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_with_tool_role() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "I'll search for that."},
|
||||
{"role": "tool", "content": "search results", "tool_call_id": "tc_1"},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is True
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_with_tool_calls() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is True
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_without_tools() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is False
|
||||
|
||||
|
||||
def test_strip_tool_content_converts_assistant_tool_calls_to_text() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Search for cats"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me search.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"query": "cats"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "Found 3 results about cats.",
|
||||
"tool_call_id": "tc_1",
|
||||
},
|
||||
{"role": "assistant", "content": "Here are the results."},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
assert len(result) == 4
|
||||
|
||||
# First message unchanged
|
||||
assert result[0] == {"role": "user", "content": "Search for cats"}
|
||||
|
||||
# Assistant with tool calls → plain text
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert "tool_calls" not in result[1]
|
||||
assert "Let me search." in result[1]["content"]
|
||||
assert "[Tool Call]" in result[1]["content"]
|
||||
assert "search" in result[1]["content"]
|
||||
assert "tc_1" in result[1]["content"]
|
||||
|
||||
# Tool response → user message
|
||||
assert result[2]["role"] == "user"
|
||||
assert "[Tool Result]" in result[2]["content"]
|
||||
assert "tc_1" in result[2]["content"]
|
||||
assert "Found 3 results about cats." in result[2]["content"]
|
||||
|
||||
# Final assistant message unchanged
|
||||
assert result[3] == {"role": "assistant", "content": "Here are the results."}
|
||||
|
||||
|
||||
def test_strip_tool_content_handles_assistant_with_no_text_content() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "[Tool Call]" in result[0]["content"]
|
||||
assert "tool_calls" not in result[0]
|
||||
|
||||
|
||||
def test_strip_tool_content_passes_through_non_tool_messages() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
assert result == messages
|
||||
|
||||
|
||||
def test_strip_tool_content_handles_list_content_blocks() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Searching now."}],
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "result A"},
|
||||
{"type": "text", "text": "result B"},
|
||||
],
|
||||
"tool_call_id": "tc_1",
|
||||
},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Assistant: list content flattened + tool call appended
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "Searching now." in result[0]["content"]
|
||||
assert "[Tool Call]" in result[0]["content"]
|
||||
assert isinstance(result[0]["content"], str)
|
||||
|
||||
# Tool: list content flattened into user message
|
||||
assert result[1]["role"] == "user"
|
||||
assert "result A" in result[1]["content"]
|
||||
assert "result B" in result[1]["content"]
|
||||
assert isinstance(result[1]["content"], str)
|
||||
|
||||
|
||||
def test_strip_tool_content_merges_consecutive_tool_results() -> None:
|
||||
"""Bedrock requires strict user/assistant alternation. Multiple parallel
|
||||
tool results must be merged into a single user message."""
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "weather and news?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search_weather", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "tc_2",
|
||||
"type": "function",
|
||||
"function": {"name": "search_news", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "sunny 72F", "tool_call_id": "tc_1"},
|
||||
{"role": "tool", "content": "headline news", "tool_call_id": "tc_2"},
|
||||
{"role": "assistant", "content": "Here are the results."},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# user, assistant (flattened), user (merged tool results), assistant
|
||||
assert len(result) == 4
|
||||
roles = [m["role"] for m in result]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
|
||||
# Both tool results merged into one user message
|
||||
merged = result[2]["content"]
|
||||
assert "tc_1" in merged
|
||||
assert "sunny 72F" in merged
|
||||
assert "tc_2" in merged
|
||||
assert "headline news" in merged
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user