mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-07 08:35:47 +00:00
Compare commits
68 Commits
nikg/std-e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29d9ebf7b3 | ||
|
|
f1df36e306 | ||
|
|
1611604269 | ||
|
|
c2a71091dc | ||
|
|
cc008699e5 | ||
|
|
48802618db | ||
|
|
6917953b86 | ||
|
|
e7cf027f8a | ||
|
|
41fb1480bb | ||
|
|
bdc2bfdcee | ||
|
|
8816d52b27 | ||
|
|
6590f1d7ba | ||
|
|
c527f75557 | ||
|
|
472d1788a7 | ||
|
|
99e95f8205 | ||
|
|
e618bf8385 | ||
|
|
f4dcd130ba | ||
|
|
910718deaa | ||
|
|
1a7ca93b93 | ||
|
|
a615a920cb | ||
|
|
29d8b310b5 | ||
|
|
d1409ccafa | ||
|
|
e41bad9103 | ||
|
|
661dc831dc | ||
|
|
19016dd35a | ||
|
|
127b2dcc80 | ||
|
|
b015a37cea | ||
|
|
b45277a8b0 | ||
|
|
893e8da79a | ||
|
|
a51f0d7cb2 | ||
|
|
c826d0469e | ||
|
|
0f6ae6f69c | ||
|
|
d0836e2603 | ||
|
|
bda03bafca | ||
|
|
376adff94a | ||
|
|
d2d4b89286 | ||
|
|
dde7a18bb7 | ||
|
|
3f004cf02f | ||
|
|
ae893079c3 | ||
|
|
189c07a913 | ||
|
|
2b82743bf5 | ||
|
|
ba2a5a60e1 | ||
|
|
5888f9d69f | ||
|
|
23b3a0a6ae | ||
|
|
eced88fa7a | ||
|
|
f59aaa902d | ||
|
|
57349bdbd1 | ||
|
|
192639a801 | ||
|
|
c10ffbb464 | ||
|
|
091f41fd1f | ||
|
|
45d77be4eb | ||
|
|
413fa85134 | ||
|
|
108cde4f55 | ||
|
|
f88ce32bd4 | ||
|
|
911f3439ea | ||
|
|
b02590d2b2 | ||
|
|
2d75b4b1f8 | ||
|
|
7e3f7d01c2 | ||
|
|
9d6ce26ea3 | ||
|
|
41713d42a2 | ||
|
|
8afc283410 | ||
|
|
b5c873077e | ||
|
|
20a4dd32eb | ||
|
|
fde0d44bc1 | ||
|
|
8fd91b6e83 | ||
|
|
8247fdd45b | ||
|
|
8c5859ba4d | ||
|
|
62ef6f59bb |
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,186 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
96
.github/workflows/deployment.yml
vendored
96
.github/workflows/deployment.yml
vendored
@@ -182,9 +182,53 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
# Create GitHub release first, before desktop builds start.
|
||||
# This ensures all desktop matrix jobs upload to the same release instead of
|
||||
# racing to create duplicate releases.
|
||||
create-release:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
release-id: ${{ steps.create-release.outputs.id }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Determine release tag
|
||||
id: release-tag
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
body: "See the assets to download this version and install."
|
||||
draft: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
build-desktop:
|
||||
needs:
|
||||
- determine-builds
|
||||
- create-release
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
@@ -208,12 +252,12 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
# NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -353,11 +397,9 @@ jobs:
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
# Use the release created by the create-release job to avoid race conditions
|
||||
# when multiple matrix jobs try to create/update the same release simultaneously
|
||||
releaseId: ${{ needs.create-release.outputs.release-id }}
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
@@ -384,7 +426,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -458,7 +500,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -527,7 +569,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -597,7 +639,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -679,7 +721,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -756,7 +798,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -823,7 +865,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -896,7 +938,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -964,7 +1006,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1034,7 +1076,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1107,7 +1149,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1176,7 +1218,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1246,7 +1288,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1326,7 +1368,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1400,7 +1442,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1465,7 +1507,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1520,7 +1562,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1580,7 +1622,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1637,7 +1679,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -15,7 +15,8 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
@@ -6,11 +6,13 @@ on:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Golang Tests
|
||||
concurrency:
|
||||
group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
GO_VERSION: "1.26"
|
||||
|
||||
jobs:
|
||||
detect-modules:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
modules: ${{ steps.set-modules.outputs.modules }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
with:
|
||||
persist-credentials: false
|
||||
- id: set-modules
|
||||
run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
|
||||
|
||||
golang:
|
||||
needs: detect-modules
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
cache-dependency-path: "**/go.sum"
|
||||
|
||||
- run: go mod tidy
|
||||
working-directory: ${{ matrix.modules }}
|
||||
- run: git diff --exit-code go.mod go.sum
|
||||
working-directory: ${{ matrix.modules }}
|
||||
|
||||
- run: go test ./...
|
||||
working-directory: ${{ matrix.modules }}
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -461,7 +461,7 @@ jobs:
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,9 +38,9 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.2.21'
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
|
||||
214
.github/workflows/release-cli.yml
vendored
Normal file
214
.github/workflows/release-cli.yml
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
name: Release CLI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "cli/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-cli
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -22,12 +22,10 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
@@ -48,6 +48,10 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN:
|
||||
description: "AWS role ARN for OIDC auth"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -73,7 +77,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -116,7 +120,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -158,7 +162,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -264,7 +268,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
6
.github/workflows/sandbox-deployment.yml
vendored
6
.github/workflows/sandbox-deployment.yml
vendored
@@ -110,7 +110,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -180,7 +180,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -244,7 +244,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -119,10 +119,11 @@ repos:
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
language_version: "1.26.0"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
|
||||
@@ -104,6 +104,10 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
- Never enqueue a task without an expiration. Always supply `expires=` when
|
||||
sending tasks, either from the beat schedule or directly from another task. It
|
||||
should never be acceptable to submit code which enqueues tasks without an
|
||||
expiration, as doing so can lead to unbounded task queue growth.
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
@@ -540,6 +544,8 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
|
||||
@@ -141,6 +141,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
|
||||
@@ -246,7 +246,11 @@ async def get_billing_information(
|
||||
)
|
||||
except OnyxError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
if e.status_code in (
|
||||
OnyxErrorCode.BAD_GATEWAY.status_code,
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
|
||||
):
|
||||
_open_billing_circuit()
|
||||
raise
|
||||
|
||||
|
||||
@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
return encoded_key
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_str.encode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_bytes.decode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -71,6 +69,8 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -84,28 +83,6 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -686,12 +663,7 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
final_tools = []
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -166,15 +167,6 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
|
||||
- NULL bytes (\x00): Not allowed in PostgreSQL text types
|
||||
- UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
|
||||
"""
|
||||
return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))
|
||||
|
||||
|
||||
def _try_parse_json_string(value: Any) -> Any:
|
||||
"""Attempt to parse a JSON string value into its Python equivalent.
|
||||
|
||||
@@ -222,9 +214,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_args, dict):
|
||||
# Parse any string values that look like JSON arrays/objects
|
||||
return {
|
||||
k: _try_parse_json_string(
|
||||
_sanitize_llm_output(v) if isinstance(v, str) else v
|
||||
)
|
||||
k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
|
||||
for k, v in raw_args.items()
|
||||
}
|
||||
|
||||
@@ -232,7 +222,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
# Sanitize before parsing to remove NULL bytes and surrogates
|
||||
raw_args = _sanitize_llm_output(raw_args)
|
||||
raw_args = sanitize_string(raw_args)
|
||||
|
||||
try:
|
||||
parsed1: Any = json.loads(raw_args)
|
||||
@@ -545,12 +535,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
return sanitize_string(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
value = sanitize_string(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
@@ -569,6 +559,7 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
arguments = sanitize_string(arguments)
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -201,8 +202,13 @@ def save_chat_turn(
|
||||
pre_answer_processing_time: Duration of processing before answer starts (in seconds)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
sanitized_message_text = (
|
||||
sanitize_string(message_text) if message_text else message_text
|
||||
)
|
||||
assistant_message.message = sanitized_message_text
|
||||
assistant_message.reasoning_tokens = (
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
)
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Use pre-answer processing time (captured when MESSAGE_START was emitted)
|
||||
@@ -212,8 +218,10 @@ def save_chat_turn(
|
||||
# Calculate token count using default tokenizer, when storing, this should not use the LLM
|
||||
# specific one so we use a system default tokenizer here.
|
||||
default_tokenizer = get_tokenizer(None, None)
|
||||
if message_text:
|
||||
assistant_message.token_count = len(default_tokenizer.encode(message_text))
|
||||
if sanitized_message_text:
|
||||
assistant_message.token_count = len(
|
||||
default_tokenizer.encode(sanitized_message_text)
|
||||
)
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
@@ -328,8 +336,10 @@ def save_chat_turn(
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if message_text:
|
||||
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
|
||||
if sanitized_message_text:
|
||||
referenced = _extract_referenced_file_descriptors(
|
||||
tool_calls, sanitized_message_text
|
||||
)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -675,58 +676,43 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
if value and not sanitized:
|
||||
logger.warning("Sanitization removed all characters from string")
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
|
||||
"""Remove NUL (0x00) characters from all strings in a list."""
|
||||
return [_sanitize_for_postgres(v) for v in values]
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> DBSearchDoc:
|
||||
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
|
||||
db_search_doc = DBSearchDoc(
|
||||
document_id=_sanitize_for_postgres(server_search_doc.document_id),
|
||||
document_id=sanitize_string(server_search_doc.document_id),
|
||||
chunk_ind=server_search_doc.chunk_ind,
|
||||
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
|
||||
semantic_id=sanitize_string(server_search_doc.semantic_identifier),
|
||||
link=(
|
||||
_sanitize_for_postgres(server_search_doc.link)
|
||||
sanitize_string(server_search_doc.link)
|
||||
if server_search_doc.link is not None
|
||||
else None
|
||||
),
|
||||
blurb=_sanitize_for_postgres(server_search_doc.blurb),
|
||||
blurb=sanitize_string(server_search_doc.blurb),
|
||||
source_type=server_search_doc.source_type,
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=(
|
||||
_sanitize_for_postgres(server_search_doc.relevance_explanation)
|
||||
sanitize_string(server_search_doc.relevance_explanation)
|
||||
if server_search_doc.relevance_explanation is not None
|
||||
else None
|
||||
),
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=_sanitize_list_for_postgres(
|
||||
server_search_doc.match_highlights
|
||||
),
|
||||
match_highlights=[
|
||||
sanitize_string(h) for h in server_search_doc.match_highlights
|
||||
],
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.primary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.primary_owners]
|
||||
if server_search_doc.primary_owners is not None
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.secondary_owners]
|
||||
if server_search_doc.secondary_owners is not None
|
||||
else None
|
||||
),
|
||||
|
||||
@@ -25,8 +25,11 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
@@ -812,6 +815,43 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default_name = db_session.scalar(
|
||||
select(ModelConfiguration.name)
|
||||
.join(
|
||||
LLMModelFlow,
|
||||
LLMModelFlow.model_configuration_id == ModelConfiguration.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.llm_provider_id == provider.id,
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
current_default_name is not None
|
||||
and current_default_name != recommended_default.name
|
||||
):
|
||||
try:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Recommended default model '%s' not found "
|
||||
"for provider_id=%s; skipping default update.",
|
||||
recommended_default.name,
|
||||
provider.id,
|
||||
)
|
||||
|
||||
return changes
|
||||
|
||||
|
||||
|
||||
161
backend/onyx/db/rotate_encryption_key.py
Normal file
161
backend/onyx/db/rotate_encryption_key.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -13,12 +13,15 @@ from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import MCPServerStatus
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.server.features.tool.models import Header
|
||||
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_json_like
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -159,10 +162,26 @@ def update_tool(
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
old_oauth_config_id = tool.oauth_config_id
|
||||
if not isinstance(oauth_config_id, UnsetType):
|
||||
tool.oauth_config_id = oauth_config_id
|
||||
db_session.commit()
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if the oauth_config_id was changed
|
||||
if (
|
||||
old_oauth_config_id is not None
|
||||
and not isinstance(oauth_config_id, UnsetType)
|
||||
and old_oauth_config_id != oauth_config_id
|
||||
):
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
|
||||
db_session.commit()
|
||||
return tool
|
||||
|
||||
|
||||
@@ -171,8 +190,21 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
oauth_config_id = tool.oauth_config_id
|
||||
|
||||
db_session.delete(tool)
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if no other tools reference it
|
||||
if oauth_config_id is not None:
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_builtin_tool(
|
||||
@@ -256,11 +288,13 @@ def create_tool_call_no_commit(
|
||||
tab_index=tab_index,
|
||||
tool_id=tool_id,
|
||||
tool_call_id=tool_call_id,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
tool_call_arguments=tool_call_arguments,
|
||||
tool_call_response=tool_call_response,
|
||||
reasoning_tokens=(
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
),
|
||||
tool_call_arguments=sanitize_json_like(tool_call_arguments),
|
||||
tool_call_response=sanitize_json_like(tool_call_response),
|
||||
tool_call_tokens=tool_call_tokens,
|
||||
generated_images=generated_images,
|
||||
generated_images=sanitize_json_like(generated_images),
|
||||
)
|
||||
|
||||
db_session.add(tool_call)
|
||||
|
||||
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -61,6 +61,25 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class IndexInfo(BaseModel):
|
||||
"""
|
||||
Represents information about an OpenSearch index.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
name: str
|
||||
health: str
|
||||
status: str
|
||||
num_primary_shards: str
|
||||
num_replica_shards: str
|
||||
docs_count: str
|
||||
docs_deleted: str
|
||||
created_at: str
|
||||
total_size: str
|
||||
primary_shards_size: str
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
@@ -159,8 +178,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -173,8 +192,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -198,6 +217,34 @@ class OpenSearchClient(AbstractContextManager):
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def list_indices_with_info(self) -> list[IndexInfo]:
|
||||
"""
|
||||
Lists the indices in the OpenSearch cluster with information about each
|
||||
index.
|
||||
|
||||
Returns:
|
||||
A list of IndexInfo objects for each index.
|
||||
"""
|
||||
response = self._client.cat.indices(format="json")
|
||||
indices: list[IndexInfo] = []
|
||||
for raw_index_info in response:
|
||||
indices.append(
|
||||
IndexInfo(
|
||||
name=raw_index_info.get("index", ""),
|
||||
health=raw_index_info.get("health", ""),
|
||||
status=raw_index_info.get("status", ""),
|
||||
num_primary_shards=raw_index_info.get("pri", ""),
|
||||
num_replica_shards=raw_index_info.get("rep", ""),
|
||||
docs_count=raw_index_info.get("docs.count", ""),
|
||||
docs_deleted=raw_index_info.get("docs.deleted", ""),
|
||||
created_at=raw_index_info.get("creation.date.string", ""),
|
||||
total_size=raw_index_info.get("store.size", ""),
|
||||
primary_shards_size=raw_index_info.get("pri.store.size", ""),
|
||||
)
|
||||
)
|
||||
return indices
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
@@ -739,7 +739,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
The number of chunks successfully deleted.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id=document_id,
|
||||
@@ -775,7 +776,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
specified documents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
for update_request in update_requests:
|
||||
properties_to_update: dict[str, Any] = dict()
|
||||
@@ -831,9 +833,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here.
|
||||
# TODO(andrei): Fix the aforementioned race condition.
|
||||
raise ChunkCountNotFoundError(
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
|
||||
"application used to permit this but is not a supported state for a document when using OpenSearch. "
|
||||
"The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. "
|
||||
"Older versions of the application used to permit this but is not a "
|
||||
"supported state for a document when using OpenSearch. The document was "
|
||||
"likely just added to the indexing pipeline and the chunk count will be "
|
||||
"updated shortly."
|
||||
)
|
||||
if doc_chunk_count == 0:
|
||||
raise ValueError(
|
||||
@@ -865,7 +869,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
@@ -912,7 +917,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# TODO(andrei): This could be better, the caller should just make this
|
||||
# decision when passing in the query param. See the above comment in the
|
||||
@@ -932,8 +938,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid
|
||||
# search from a theoretical standpoint. Empirically on a small dataset
|
||||
# of up to 10K docs, it's not very different. Likely more impactful at
|
||||
# scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
@@ -960,7 +968,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
dirty: bool | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_random_search_query(
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -990,7 +999,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
complete.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
|
||||
@@ -243,7 +243,8 @@ class DocumentChunk(BaseModel):
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@@ -284,19 +285,22 @@ class DocumentChunk(BaseModel):
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
|
||||
f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
|
||||
"current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but "
|
||||
"multi-tenant mode is not enabled. This is unexpected because in single-tenant "
|
||||
"mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
@@ -352,8 +356,10 @@ class DocumentSchema:
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
# Language analyzer (e.g. english) stems at index and search time for variant matching.
|
||||
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
|
||||
# Language analyzer (e.g. english) stems at index and search
|
||||
# time for variant matching. Configure via
|
||||
# OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
|
||||
# after a change.
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
|
||||
@@ -698,41 +698,6 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -758,41 +723,53 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -37,30 +34,22 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -73,16 +62,12 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -104,8 +89,7 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -114,16 +98,14 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -135,8 +117,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -147,8 +129,7 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -160,73 +141,94 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# filter_str += _build_kg_filter(
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# )
|
||||
# ))
|
||||
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# )
|
||||
# ))
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -48,10 +48,11 @@ class OnyxError(Exception):
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_message = message or error_code.code
|
||||
super().__init__(resolved_message)
|
||||
self.error_code = error_code
|
||||
self.message = message or error_code.code
|
||||
self.message = resolved_message
|
||||
self._status_code_override = status_code_override
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
|
||||
@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -65,6 +64,7 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ class LlmProviderNames(str, Enum):
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE = "azure"
|
||||
OLLAMA_CHAT = "ollama_chat"
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
|
||||
@@ -41,6 +42,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
]
|
||||
|
||||
|
||||
@@ -56,6 +58,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.AZURE: "Azure",
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -103,6 +106,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ from onyx.llm.multi_llm import LitellmLLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
@@ -32,14 +34,18 @@ logger = setup_logger()
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
|
||||
raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
|
||||
api_key = raw.strip() if raw else None
|
||||
if not api_key:
|
||||
return {}
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
return {"Authorization": api_key}
|
||||
return {
|
||||
"Authorization": (
|
||||
api_key
|
||||
if api_key.lower().startswith("bearer ")
|
||||
else f"Bearer {api_key}"
|
||||
)
|
||||
}
|
||||
|
||||
# Passing these will put Onyx on the OpenRouter leaderboard
|
||||
elif provider == LlmProviderNames.OPENROUTER:
|
||||
|
||||
@@ -1512,6 +1512,10 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1526,6 +1530,10 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -2516,6 +2524,10 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.4": {
|
||||
"display_name": "GPT-5.4",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
@@ -92,6 +93,98 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
return [prompt.model_dump(exclude_none=True)]
|
||||
|
||||
|
||||
def _normalize_content(raw: Any) -> str:
|
||||
"""Normalize a message content field to a plain string.
|
||||
|
||||
Content can be a string, None, or a list of content-block dicts
|
||||
(e.g. [{"type": "text", "text": "..."}]).
|
||||
"""
|
||||
if raw is None:
|
||||
return ""
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, list):
|
||||
return "\n".join(
|
||||
block.get("text", "") if isinstance(block, dict) else str(block)
|
||||
for block in raw
|
||||
)
|
||||
return str(raw)
|
||||
|
||||
|
||||
def _strip_tool_content_from_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tool-related messages to plain text.
|
||||
|
||||
Bedrock's Converse API requires toolConfig when messages contain
|
||||
toolUse/toolResult content blocks. When no tools are provided for the
|
||||
current request, we must convert any tool-related history into plain text
|
||||
to avoid the "toolConfig field must be defined" error.
|
||||
|
||||
This is the same approach used by _OllamaHistoryMessageFormatter.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert structured tool calls to text representation
|
||||
tool_call_lines = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
args = func.get("arguments", "{}")
|
||||
tc_id = tc.get("id", "")
|
||||
tool_call_lines.append(
|
||||
f"[Tool Call] name={name} id={tc_id} args={args}"
|
||||
)
|
||||
|
||||
existing_content = _normalize_content(msg.get("content"))
|
||||
parts = (
|
||||
[existing_content] + tool_call_lines
|
||||
if existing_content
|
||||
else tool_call_lines
|
||||
)
|
||||
new_msg = {
|
||||
"role": "assistant",
|
||||
"content": "\n".join(parts),
|
||||
}
|
||||
result.append(new_msg)
|
||||
|
||||
elif role == "tool":
|
||||
# Convert tool response to user message with text content
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = _normalize_content(msg.get("content"))
|
||||
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
|
||||
# Merge into previous user message if it is also a converted
|
||||
# tool result to avoid consecutive user messages (Bedrock requires
|
||||
# strict user/assistant alternation).
|
||||
if (
|
||||
result
|
||||
and result[-1]["role"] == "user"
|
||||
and "[Tool Result]" in result[-1].get("content", "")
|
||||
):
|
||||
result[-1]["content"] += "\n\n" + tool_result_text
|
||||
else:
|
||||
result.append({"role": "user", "content": tool_result_text})
|
||||
|
||||
else:
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -157,6 +250,9 @@ class LitellmLLM(LLM):
|
||||
elif model_provider == LlmProviderNames.OLLAMA_CHAT:
|
||||
if k == OLLAMA_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.LM_STUDIO:
|
||||
if k == LM_STUDIO_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.BEDROCK:
|
||||
if k == AWS_REGION_NAME_KWARG:
|
||||
model_kwargs[k] = v
|
||||
@@ -173,6 +269,19 @@ class LitellmLLM(LLM):
|
||||
elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT:
|
||||
model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v
|
||||
|
||||
# LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided,
|
||||
# which LM Studio rejects. Ensure we always pass an explicit key (or empty
|
||||
# string) to prevent LiteLLM from injecting its fake default.
|
||||
if model_provider == LlmProviderNames.LM_STUDIO:
|
||||
model_kwargs.setdefault("api_key", "")
|
||||
|
||||
# Users provide the server root (e.g. http://localhost:1234) but LiteLLM
|
||||
# needs /v1 for OpenAI-compatible calls.
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# Default vertex_location to "global" if not provided for Vertex AI
|
||||
# Latest gemini models are only available through the global region
|
||||
if (
|
||||
@@ -404,13 +513,30 @@ class LitellmLLM(LLM):
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
messages = _prompt_to_dicts(prompt)
|
||||
|
||||
# Bedrock's Converse API requires toolConfig when messages
|
||||
# contain toolUse/toolResult content blocks. When no tools are
|
||||
# provided for this request but the history contains tool
|
||||
# content from previous turns, strip it to plain text.
|
||||
is_bedrock = self._model_provider in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}
|
||||
if (
|
||||
is_bedrock
|
||||
and not tools
|
||||
and _messages_contain_tool_content(messages)
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
|
||||
@@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None:
|
||||
error_msg = None
|
||||
for _ in range(2):
|
||||
try:
|
||||
llm.invoke(UserMessage(content="Do not respond"))
|
||||
llm.invoke(UserMessage(content="Do not respond"), max_tokens=50)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
@@ -1,52 +1,27 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY,
|
||||
}
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -54,13 +29,6 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
@@ -44,6 +45,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
}
|
||||
|
||||
@@ -323,6 +325,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
_ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)",
|
||||
OLLAMA_PROVIDER_NAME: "Ollama",
|
||||
LM_STUDIO_PROVIDER_NAME: "LM Studio",
|
||||
ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)",
|
||||
AZURE_PROVIDER_NAME: "Azure OpenAI",
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-02-05T00:00:00Z",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.2" },
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
"additional_visible_models": [
|
||||
{ "name": "gpt-5-mini" },
|
||||
{ "name": "gpt-4.1" }
|
||||
{ "name": "gpt-5.4" },
|
||||
{ "name": "gpt-5.2" }
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
@@ -16,6 +16,10 @@
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.9",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz",
|
||||
"integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==",
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1573,27 +1573,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/balanced-match": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
|
||||
"integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/brace-expansion": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
|
||||
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@isaacs/balanced-match": "^4.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/gen-mapping": {
|
||||
"version": "0.3.13",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz",
|
||||
@@ -1680,9 +1659,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -3855,6 +3834,27 @@
|
||||
"path-browserify": "^1.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/balanced-match": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
|
||||
"integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/fast-glob": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz",
|
||||
@@ -3884,15 +3884,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/minimatch": {
|
||||
"version": "10.1.1",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz",
|
||||
"integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==",
|
||||
"version": "10.2.4",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz",
|
||||
"integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"dependencies": {
|
||||
"@isaacs/brace-expansion": "^5.0.0"
|
||||
"brace-expansion": "^5.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
"node": "18 || 20 || >=22"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
@@ -4234,13 +4234,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
|
||||
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
|
||||
"version": "9.0.9",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
|
||||
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^2.0.1"
|
||||
"brace-expansion": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16 || 14 >=14.17"
|
||||
@@ -4619,9 +4619,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "6.12.6",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"version": "6.14.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
|
||||
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -4653,9 +4653,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv-formats/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -6758,12 +6758,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
|
||||
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.0.1"
|
||||
"ip-address": "10.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
@@ -7556,9 +7556,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
@@ -8831,9 +8831,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
@@ -9699,9 +9699,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.1",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
|
||||
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"side-channel": "^1.1.0"
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -27,8 +28,6 @@ from onyx.db.feedback import update_document_boost_for_user
|
||||
from onyx.db.feedback import update_document_hidden_for_user
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -125,11 +124,11 @@ def validate_existing_genai_api_key(
|
||||
try:
|
||||
llm = get_default_llm(timeout=10)
|
||||
except ValueError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "LLM not setup")
|
||||
raise HTTPException(status_code=404, detail="LLM not setup")
|
||||
|
||||
error = test_llm(llm)
|
||||
if error:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error)
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
|
||||
# Mark check as successful
|
||||
curr_time = datetime.now(tz=timezone.utc)
|
||||
@@ -160,7 +159,10 @@ def create_deletion_attempt_for_connector_id(
|
||||
f"'{credential_id}' does not exist. Has it already been deleted?"
|
||||
)
|
||||
logger.error(error)
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, error)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error,
|
||||
)
|
||||
|
||||
# Cancel any scheduled indexing attempts
|
||||
cancel_indexing_attempts_for_ccpair(
|
||||
@@ -176,9 +178,9 @@ def create_deletion_attempt_for_connector_id(
|
||||
# connector_credential_pair=cc_pair, db_session=db_session
|
||||
# )
|
||||
# if deletion_attempt_disallowed_reason:
|
||||
# raise OnyxError(
|
||||
# OnyxErrorCode.VALIDATION_ERROR,
|
||||
# deletion_attempt_disallowed_reason,
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=deletion_attempt_disallowed_reason,
|
||||
# )
|
||||
|
||||
# mark as deleting
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -22,8 +24,6 @@ from onyx.db.discord_bot import update_discord_channel_config
|
||||
from onyx.db.discord_bot import update_guild_config
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigCreateRequest
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordChannelConfigResponse
|
||||
@@ -47,14 +47,14 @@ def _check_bot_config_api_access() -> None:
|
||||
- When DISCORD_BOT_TOKEN env var is set (managed via env)
|
||||
"""
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Discord bot configuration is managed by Onyx on Cloud.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot configuration is managed by Onyx on Cloud.",
|
||||
)
|
||||
if DISCORD_BOT_TOKEN:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Discord bot is configured via environment variables. API access disabled.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot is configured via environment variables. API access disabled.",
|
||||
)
|
||||
|
||||
|
||||
@@ -92,9 +92,9 @@ def create_bot_request(
|
||||
bot_token=request.bot_token,
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
"Discord bot config already exists. Delete it first to create a new one.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Discord bot config already exists. Delete it first to create a new one.",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
@@ -117,7 +117,7 @@ def delete_bot_config_endpoint(
|
||||
"""
|
||||
deleted = delete_discord_bot_config(db_session)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Bot config not found")
|
||||
raise HTTPException(status_code=404, detail="Bot config not found")
|
||||
|
||||
# Also delete the service API key used by the Discord bot
|
||||
delete_discord_service_api_key(db_session)
|
||||
@@ -144,7 +144,7 @@ def delete_service_api_key_endpoint(
|
||||
"""
|
||||
deleted = delete_discord_service_api_key(db_session)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Service API key not found")
|
||||
raise HTTPException(status_code=404, detail="Service API key not found")
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
@@ -189,7 +189,7 @@ def get_guild_config(
|
||||
"""Get specific guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
return DiscordGuildConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ def update_guild_request(
|
||||
"""Update guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
config = update_guild_config(
|
||||
db_session,
|
||||
@@ -228,7 +228,7 @@ def delete_guild_request(
|
||||
"""
|
||||
deleted = delete_guild_config(db_session, config_id)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
# On Cloud, delete service API key when all guilds are removed
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
@@ -254,9 +254,9 @@ def list_channel_configs(
|
||||
"""List whitelisted channels for a guild."""
|
||||
guild_config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not guild_config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
if not guild_config.guild_id:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Guild not yet registered")
|
||||
raise HTTPException(status_code=400, detail="Guild not yet registered")
|
||||
|
||||
configs = get_channel_configs(db_session, config_id)
|
||||
return [DiscordChannelConfigResponse.model_validate(c) for c in configs]
|
||||
@@ -278,7 +278,7 @@ def update_channel_request(
|
||||
db_session, guild_config_id, channel_config_id
|
||||
)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Channel config not found")
|
||||
raise HTTPException(status_code=404, detail="Channel config not found")
|
||||
|
||||
config = update_discord_channel_config(
|
||||
db_session,
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
@@ -15,8 +16,6 @@ from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.models import AllVersions
|
||||
from onyx.server.manage.models import AuthTypeResponse
|
||||
from onyx.server.manage.models import ContainerVersions
|
||||
@@ -105,14 +104,14 @@ def get_versions() -> AllVersions:
|
||||
|
||||
# Ensure we have at least one tag of each type
|
||||
if not dev_tags:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"No valid dev versions found matching pattern v(number).(number).(number)-beta.(number)",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid dev versions found matching pattern v(number).(number).(number)-beta.(number)",
|
||||
)
|
||||
if not stable_tags:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"No valid stable versions found matching pattern v(number).(number).(number)",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid stable versions found matching pattern v(number).(number).(number)",
|
||||
)
|
||||
|
||||
# Sort common tags and get the latest one
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -14,8 +15,6 @@ from onyx.db.llm import remove_llm_provider__no_commit
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
@@ -75,9 +74,9 @@ def _build_llm_provider_request(
|
||||
# Clone mode: Only use API key from source provider
|
||||
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -111,9 +110,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
@@ -125,9 +124,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not validate_credentials(provider, credentials):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Incorrect credentials for {provider}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Incorrect credentials for {provider}",
|
||||
)
|
||||
|
||||
return LLMProviderUpsertRequest(
|
||||
@@ -216,9 +215,9 @@ def test_image_generation(
|
||||
LLMProviderModel, test_request.source_llm_provider_id
|
||||
)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -237,9 +236,9 @@ def test_image_generation(
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -258,14 +257,14 @@ def test_image_generation(
|
||||
),
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Invalid image generation provider: {provider}",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Invalid image generation provider: {provider}",
|
||||
)
|
||||
except ImageProviderCredentialsError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHENTICATED,
|
||||
"Invalid image generation credentials",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid image generation credentials",
|
||||
)
|
||||
|
||||
quality = _get_test_quality_for_model(test_request.model_name)
|
||||
@@ -277,15 +276,15 @@ def test_image_generation(
|
||||
n=1,
|
||||
quality=quality,
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log only exception type to avoid exposing sensitive data
|
||||
# (LiteLLM errors may contain URLs with API keys or auth tokens)
|
||||
logger.warning(f"Image generation test failed: {type(e).__name__}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Image generation test failed: {type(e).__name__}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image generation test failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
|
||||
@@ -310,9 +309,9 @@ def create_config(
|
||||
db_session, config_create.image_provider_id
|
||||
)
|
||||
if existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -346,10 +345,10 @@ def create_config(
|
||||
db_session.commit()
|
||||
db_session.refresh(config)
|
||||
return ImageGenerationConfigView.from_model(config)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/config")
|
||||
@@ -374,9 +373,9 @@ def get_config_credentials(
|
||||
"""
|
||||
config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
return ImageGenerationCredentials.from_model(config)
|
||||
@@ -402,9 +401,9 @@ def update_config(
|
||||
# 1. Get existing config
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -473,10 +472,10 @@ def update_config(
|
||||
db_session.refresh(existing_config)
|
||||
return ImageGenerationConfigView.from_model(existing_config)
|
||||
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}")
|
||||
@@ -490,9 +489,9 @@ def delete_config(
|
||||
# Get the config first to find the associated LLM provider
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -504,10 +503,10 @@ def delete_config(
|
||||
remove_llm_provider__no_commit(db_session, llm_provider_id)
|
||||
|
||||
db_session.commit()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/config/{image_provider_id}/default")
|
||||
@@ -520,7 +519,7 @@ def set_config_as_default(
|
||||
try:
|
||||
set_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}/default")
|
||||
@@ -533,4 +532,4 @@ def unset_config_as_default(
|
||||
try:
|
||||
unset_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_available_well_known_llms,
|
||||
)
|
||||
@@ -62,6 +63,8 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -73,6 +76,7 @@ from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
|
||||
@@ -441,6 +445,17 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# Before the upsert, check if this provider currently owns the global
|
||||
# CHAT default. The upsert may cascade-delete model_configurations
|
||||
# (and their flow mappings), so we need to remember this beforehand.
|
||||
was_default_provider = False
|
||||
if existing_provider and transitioning_to_auto_mode:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
was_default_provider = (
|
||||
current_default is not None
|
||||
and current_default.llm_provider_id == existing_provider.id
|
||||
)
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
llm_provider_upsert_request=llm_provider_upsert_request,
|
||||
@@ -463,6 +478,20 @@ def put_llm_provider(
|
||||
updated_provider,
|
||||
config,
|
||||
)
|
||||
|
||||
# If this provider was the default before the transition,
|
||||
# restore the default using the recommended model.
|
||||
if was_default_provider:
|
||||
recommended = config.get_default_model(
|
||||
llm_provider_upsert_request.provider
|
||||
)
|
||||
if recommended:
|
||||
update_default_provider(
|
||||
provider_id=updated_provider.id,
|
||||
model_name=recommended.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Refresh result with synced models
|
||||
result = LLMProviderView.from_model(updated_provider)
|
||||
|
||||
@@ -1217,3 +1246,117 @@ def get_openrouter_available_models(
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/lm-studio/available-models")
|
||||
def get_lm_studio_available_models(
|
||||
request: LMStudioModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LMStudioFinalModelResponse]:
|
||||
"""Fetch available models from an LM Studio server.
|
||||
|
||||
Uses the LM Studio-native /api/v1/models endpoint which exposes
|
||||
rich metadata including capabilities (vision, reasoning),
|
||||
display names, and context lengths.
|
||||
"""
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
# Strip /v1 suffix that users may copy from OpenAI-compatible tool configs;
|
||||
# the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models.
|
||||
cleaned_api_base = cleaned_api_base.removesuffix("/v1")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch LM Studio models.",
|
||||
)
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LM Studio models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your LM Studio server.",
|
||||
)
|
||||
|
||||
results: list[LMStudioFinalModelResponse] = []
|
||||
for item in models:
|
||||
# Filter to LLM-type models only (skip embeddings, etc.)
|
||||
if item.get("type") != "llm":
|
||||
continue
|
||||
|
||||
model_key = item.get("key")
|
||||
if not model_key:
|
||||
continue
|
||||
|
||||
display_name = item.get("display_name") or model_key
|
||||
max_context_length = item.get("max_context_length")
|
||||
capabilities = item.get("capabilities") or {}
|
||||
|
||||
results.append(
|
||||
LMStudioFinalModelResponse(
|
||||
name=model_key,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_context_length,
|
||||
supports_image_input=capabilities.get("vision", False),
|
||||
supports_reasoning=capabilities.get("reasoning", False)
|
||||
or is_reasoning_model(model_key, display_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from LM Studio server.",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -371,6 +371,22 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# LM Studio dynamic models fetch
|
||||
class LMStudioModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
api_key_changed: bool = False
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LMStudioFinalModelResponse(BaseModel):
|
||||
name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B")
|
||||
display_name: str # Human-readable name
|
||||
max_input_tokens: int | None # From LM Studio API or None if unavailable
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import TypedDict
|
||||
|
||||
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR
|
||||
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -23,6 +24,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -348,4 +350,19 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
elif provider == LlmProviderNames.LM_STUDIO:
|
||||
# LM Studio model IDs can be paths like "publisher/model-name"
|
||||
# or simple names. Use MODEL_PREFIX_TO_VENDOR for matching.
|
||||
|
||||
model_lower = model_name.lower()
|
||||
# Check for slash-separated vendor prefix first
|
||||
if "/" in model_lower:
|
||||
vendor_key = model_lower.split("/")[0]
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
# Fallback to model prefix matching
|
||||
for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items():
|
||||
if model_lower.startswith(prefix):
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title())
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -25,8 +27,6 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
@@ -58,9 +58,9 @@ def set_new_search_settings(
|
||||
|
||||
# Disallow contextual RAG for cloud deployments.
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider.
|
||||
@@ -70,9 +70,9 @@ def set_new_search_settings(
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -188,7 +188,7 @@ def delete_search_settings_endpoint(
|
||||
search_settings_id=deletion_request.search_settings_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
@@ -238,9 +238,9 @@ def update_saved_search_settings(
|
||||
) -> None:
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -294,7 +294,7 @@ def validate_contextual_rag_model(
|
||||
model_name=model_name,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
|
||||
|
||||
|
||||
def _validate_contextual_rag_model(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -20,8 +21,6 @@ from onyx.db.slack_channel_config import fetch_slack_channel_configs
|
||||
from onyx.db.slack_channel_config import insert_slack_channel_config
|
||||
from onyx.db.slack_channel_config import remove_slack_channel_config
|
||||
from onyx.db.slack_channel_config import update_slack_channel_config
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.onyxbot.slack.config import validate_channel_name
|
||||
from onyx.server.manage.models import SlackBot
|
||||
from onyx.server.manage.models import SlackBotCreationRequest
|
||||
@@ -64,7 +63,10 @@ def _form_channel_config(
|
||||
current_slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
if respond_tag_only and respond_member_group_list:
|
||||
raise ValueError(
|
||||
@@ -121,7 +123,10 @@ def create_slack_channel_config(
|
||||
)
|
||||
|
||||
if channel_config["channel_name"] is None:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Channel name is required")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Channel name is required",
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
if slack_channel_config_creation_request.persona_id is not None:
|
||||
@@ -166,7 +171,10 @@ def patch_slack_channel_config(
|
||||
db_session=db_session, slack_channel_config_id=slack_channel_config_id
|
||||
)
|
||||
if existing_slack_channel_config is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Slack channel config not found")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Slack channel config not found",
|
||||
)
|
||||
|
||||
existing_persona_id = existing_slack_channel_config.persona_id
|
||||
if existing_persona_id is not None:
|
||||
|
||||
@@ -13,6 +13,7 @@ from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
@@ -72,8 +73,6 @@ from onyx.db.users import get_page_of_filtered_users
|
||||
from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
@@ -125,7 +124,7 @@ def set_user_role(
|
||||
email=user_role_update_request.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_update:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
current_role = user_to_update.role
|
||||
requested_role = user_role_update_request.new_role
|
||||
@@ -140,9 +139,9 @@ def set_user_role(
|
||||
)
|
||||
|
||||
if user_to_update.id == current_user.id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"An admin cannot demote themselves from admin role!",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="An admin cannot demote themselves from admin role!",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.CURATOR:
|
||||
@@ -387,9 +386,9 @@ def bulk_invite_users(
|
||||
new_invited_emails.append(email_info.normalized)
|
||||
|
||||
except (EmailUndeliverableError, EmailNotValidError) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid email address: {email} - {str(e)}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid email address: {email} - {str(e)}",
|
||||
)
|
||||
|
||||
# Count only new users (not already invited or existing) that need seats
|
||||
@@ -406,9 +405,9 @@ def bulk_invite_users(
|
||||
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
|
||||
current_invited = len(already_invited)
|
||||
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"You have hit your invite limit. "
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You have hit your invite limit. "
|
||||
"Please upgrade for unlimited invites.",
|
||||
)
|
||||
|
||||
@@ -503,16 +502,14 @@ def deactivate_user_api(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if current_user.email == user_email.user_email:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "You cannot deactivate yourself"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You cannot deactivate yourself")
|
||||
|
||||
user_to_deactivate = get_user_by_email(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
|
||||
if not user_to_deactivate:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_deactivate.is_active is False:
|
||||
logger.warning("{} is already deactivated".format(user_to_deactivate.email))
|
||||
@@ -537,15 +534,14 @@ async def delete_user(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_delete:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_delete.is_active is True:
|
||||
logger.warning(
|
||||
"{} must be deactivated before deleting".format(user_to_delete.email)
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"User must be deactivated before deleting",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="User must be deactivated before deleting"
|
||||
)
|
||||
|
||||
# Detach the user from the current session
|
||||
@@ -569,7 +565,7 @@ async def delete_user(
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Error deleting user")
|
||||
raise HTTPException(status_code=500, detail="Error deleting user")
|
||||
|
||||
|
||||
@router.patch("/manage/admin/activate-user", tags=PUBLIC_API_TAGS)
|
||||
@@ -582,7 +578,7 @@ def activate_user_api(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_activate:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_activate.is_active is True:
|
||||
logger.warning("{} is already activated".format(user_to_activate.email))
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.configs.constants import SLACK_USER_TOKEN_PREFIX
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
SLACK_API_URL = "https://slack.com/api/auth.test"
|
||||
SLACK_CONNECTIONS_OPEN_URL = "https://slack.com/api/apps.connections.open"
|
||||
@@ -13,15 +12,15 @@ def validate_bot_token(bot_token: str) -> bool:
|
||||
response = requests.post(SLACK_API_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid bot token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid bot token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -32,15 +31,15 @@ def validate_app_token(app_token: str) -> bool:
|
||||
response = requests.post(SLACK_CONNECTIONS_OPEN_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid app token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid app token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -55,16 +54,16 @@ def validate_user_token(user_token: str | None) -> None:
|
||||
Returns:
|
||||
None is valid and will return successfully.
|
||||
Raises:
|
||||
OnyxError: If the token is invalid or missing required fields
|
||||
HTTPException: If the token is invalid or missing required fields
|
||||
"""
|
||||
if not user_token:
|
||||
# user_token is optional, so None or empty string is valid
|
||||
return
|
||||
|
||||
if not user_token.startswith(SLACK_USER_TOKEN_PREFIX):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid user token format. User OAuth tokens must start with '{SLACK_USER_TOKEN_PREFIX}'",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user token format. User OAuth tokens must start with '{SLACK_USER_TOKEN_PREFIX}'",
|
||||
)
|
||||
|
||||
# Test the token with Slack API to ensure it's valid
|
||||
@@ -72,13 +71,13 @@ def validate_user_token(user_token: str | None) -> None:
|
||||
response = requests.post(SLACK_API_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid user token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -25,8 +26,6 @@ from onyx.db.web_search import set_active_web_content_provider
|
||||
from onyx.db.web_search import set_active_web_search_provider
|
||||
from onyx.db.web_search import upsert_web_content_provider
|
||||
from onyx.db.web_search import upsert_web_search_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.web_search.models import WebContentProviderTestRequest
|
||||
from onyx.server.manage.web_search.models import WebContentProviderUpsertRequest
|
||||
from onyx.server.manage.web_search.models import WebContentProviderView
|
||||
@@ -87,9 +86,9 @@ def upsert_search_provider_endpoint(
|
||||
and request.id is not None
|
||||
and existing_by_name.id != request.id
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"A search provider named '{request.name}' already exists.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"A search provider named '{request.name}' already exists.",
|
||||
)
|
||||
|
||||
provider = upsert_web_search_provider(
|
||||
@@ -194,16 +193,16 @@ def test_search_provider(
|
||||
request.provider_type, db_session
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if requires_key and not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -213,21 +212,20 @@ def test_search_provider(
|
||||
config=request.config or {},
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Unable to build provider configuration.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Unable to build provider configuration."
|
||||
)
|
||||
|
||||
# Run the API client's test_connection method to ensure the connection is valid.
|
||||
try:
|
||||
return provider.test_connection()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) from e
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
|
||||
@@ -261,9 +259,9 @@ def upsert_content_provider_endpoint(
|
||||
and request.id is not None
|
||||
and existing_by_name.id != request.id
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"A content provider named '{request.name}' already exists.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"A content provider named '{request.name}' already exists.",
|
||||
)
|
||||
|
||||
provider = upsert_web_content_provider(
|
||||
@@ -381,9 +379,9 @@ def test_content_provider(
|
||||
request.provider_type, db_session
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
stored_base_url = (
|
||||
@@ -391,17 +389,17 @@ def test_content_provider(
|
||||
)
|
||||
request_base_url = request.config.base_url
|
||||
if request_base_url != stored_base_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Base URL cannot differ from stored provider when using stored API key",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Base URL cannot differ from stored provider when using stored API key",
|
||||
)
|
||||
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -411,12 +409,11 @@ def test_content_provider(
|
||||
config=request.config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Unable to build provider configuration.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Unable to build provider configuration."
|
||||
)
|
||||
|
||||
# Actually test the API key by making a real content fetch call
|
||||
@@ -428,11 +425,11 @@ def test_content_provider(
|
||||
if not test_results or not any(
|
||||
result.scrape_successful for result in test_results
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key validation failed: content fetch returned no results.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key validation failed: content fetch returned no results.",
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@@ -441,13 +438,13 @@ def test_content_provider(
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid API key: {error_msg}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid API key: {error_msg}",
|
||||
) from e
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"API key validation failed: {error_msg}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -111,19 +111,26 @@ def _normalize_text_with_mapping(text: str) -> tuple[str, list[int]]:
|
||||
# Step 1: NFC normalization with position mapping
|
||||
nfc_text = unicodedata.normalize("NFC", text)
|
||||
|
||||
# Build mapping from NFC positions to original start positions
|
||||
# Map NFD positions → original positions.
|
||||
# NFD only decomposes, so each original char produces 1+ NFD chars.
|
||||
nfd_to_orig: list[int] = []
|
||||
for orig_idx, orig_char in enumerate(original_text):
|
||||
nfd_of_char = unicodedata.normalize("NFD", orig_char)
|
||||
for _ in nfd_of_char:
|
||||
nfd_to_orig.append(orig_idx)
|
||||
|
||||
# Map NFC positions → NFD positions.
|
||||
# Each NFC char, when decomposed, tells us exactly how many NFD
|
||||
# chars it was composed from.
|
||||
nfc_to_orig: list[int] = []
|
||||
orig_idx = 0
|
||||
nfd_idx = 0
|
||||
for nfc_char in nfc_text:
|
||||
nfc_to_orig.append(orig_idx)
|
||||
# Find how many original chars contributed to this NFC char
|
||||
for length in range(1, len(original_text) - orig_idx + 1):
|
||||
substr = original_text[orig_idx : orig_idx + length]
|
||||
if unicodedata.normalize("NFC", substr) == nfc_char:
|
||||
orig_idx += length
|
||||
break
|
||||
if nfd_idx < len(nfd_to_orig):
|
||||
nfc_to_orig.append(nfd_to_orig[nfd_idx])
|
||||
else:
|
||||
orig_idx += 1 # Fallback
|
||||
nfc_to_orig.append(len(original_text) - 1)
|
||||
nfd_of_nfc = unicodedata.normalize("NFD", nfc_char)
|
||||
nfd_idx += len(nfd_of_nfc)
|
||||
|
||||
# Work with NFC text from here
|
||||
text = nfc_text
|
||||
|
||||
@@ -11,16 +11,20 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT encrypt called with explicit key — key ignored.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support decryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT decrypt called with explicit key — key ignored.")
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
@@ -86,15 +90,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(intput_str)
|
||||
return versioned_encryption_fn(intput_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(intput_bytes)
|
||||
return versioned_decryption_fn(intput_bytes, key=key)
|
||||
|
||||
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
427
backend/onyx/utils/jsonriver/parse.py
Normal file
427
backend/onyx/utils/jsonriver/parse.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if (
|
||||
prev_val
|
||||
and len(cur_val) >= len(prev_val)
|
||||
and cur_val[len(prev_val) - 1] != prev_val[-1]
|
||||
):
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if (
|
||||
prev
|
||||
and len(current) >= len(prev)
|
||||
and current[len(prev) - 1] != prev[-1]
|
||||
):
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -1,30 +1,49 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
def sanitize_string(value: str) -> str:
|
||||
"""Strip characters that PostgreSQL text/JSONB columns cannot store.
|
||||
|
||||
Removes:
|
||||
- NUL bytes (\\x00)
|
||||
- UTF-16 surrogates (\\ud800-\\udfff), which are invalid in UTF-8
|
||||
"""
|
||||
sanitized = value.replace("\x00", "")
|
||||
sanitized = _SURROGATE_RE.sub("", sanitized)
|
||||
if value and not sanitized:
|
||||
logger.warning(
|
||||
"sanitize_string: all characters were removed from a non-empty string"
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_json_like(value: Any) -> Any:
|
||||
def sanitize_json_like(value: Any) -> Any:
|
||||
"""Recursively sanitize all strings in a JSON-like structure (dict/list/tuple)."""
|
||||
if isinstance(value, str):
|
||||
return _sanitize_string(value)
|
||||
return sanitize_string(value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_json_like(item) for item in value]
|
||||
return [sanitize_json_like(item) for item in value]
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_sanitize_json_like(item) for item in value)
|
||||
return tuple(sanitize_json_like(item) for item in value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[Any, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
|
||||
cleaned_key = sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = sanitize_json_like(nested_value)
|
||||
return sanitized
|
||||
|
||||
return value
|
||||
@@ -34,27 +53,27 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
return expert.model_copy(
|
||||
update={
|
||||
"display_name": (
|
||||
_sanitize_string(expert.display_name)
|
||||
sanitize_string(expert.display_name)
|
||||
if expert.display_name is not None
|
||||
else None
|
||||
),
|
||||
"first_name": (
|
||||
_sanitize_string(expert.first_name)
|
||||
sanitize_string(expert.first_name)
|
||||
if expert.first_name is not None
|
||||
else None
|
||||
),
|
||||
"middle_initial": (
|
||||
_sanitize_string(expert.middle_initial)
|
||||
sanitize_string(expert.middle_initial)
|
||||
if expert.middle_initial is not None
|
||||
else None
|
||||
),
|
||||
"last_name": (
|
||||
_sanitize_string(expert.last_name)
|
||||
sanitize_string(expert.last_name)
|
||||
if expert.last_name is not None
|
||||
else None
|
||||
),
|
||||
"email": (
|
||||
_sanitize_string(expert.email) if expert.email is not None else None
|
||||
sanitize_string(expert.email) if expert.email is not None else None
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -63,10 +82,10 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
|
||||
return ExternalAccess(
|
||||
external_user_emails={
|
||||
_sanitize_string(email) for email in external_access.external_user_emails
|
||||
sanitize_string(email) for email in external_access.external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
_sanitize_string(group_id)
|
||||
sanitize_string(group_id)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
},
|
||||
is_public=external_access.is_public,
|
||||
@@ -76,26 +95,26 @@ def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
cleaned_doc = document.model_copy(deep=True)
|
||||
|
||||
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
|
||||
cleaned_doc.id = sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = sanitize_string(cleaned_doc.semantic_identifier)
|
||||
if cleaned_doc.title is not None:
|
||||
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
|
||||
cleaned_doc.title = sanitize_string(cleaned_doc.title)
|
||||
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id
|
||||
)
|
||||
|
||||
cleaned_doc.metadata = {
|
||||
_sanitize_string(key): (
|
||||
[_sanitize_string(item) for item in value]
|
||||
sanitize_string(key): (
|
||||
[sanitize_string(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else _sanitize_string(value)
|
||||
else sanitize_string(value)
|
||||
)
|
||||
for key, value in cleaned_doc.metadata.items()
|
||||
}
|
||||
|
||||
if cleaned_doc.doc_metadata is not None:
|
||||
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
cleaned_doc.doc_metadata = sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
|
||||
if cleaned_doc.primary_owners is not None:
|
||||
cleaned_doc.primary_owners = [
|
||||
@@ -113,11 +132,11 @@ def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = _sanitize_string(section.link)
|
||||
section.link = sanitize_string(section.link)
|
||||
if section.text is not None:
|
||||
section.text = _sanitize_string(section.text)
|
||||
section.text = sanitize_string(section.text)
|
||||
if section.image_file_id is not None:
|
||||
section.image_file_id = _sanitize_string(section.image_file_id)
|
||||
section.image_file_id = sanitize_string(section.image_file_id)
|
||||
|
||||
return cleaned_doc
|
||||
|
||||
@@ -129,12 +148,12 @@ def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
|
||||
cleaned_node = node.model_copy(deep=True)
|
||||
|
||||
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
|
||||
cleaned_node.raw_node_id = sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = sanitize_string(cleaned_node.display_name)
|
||||
if cleaned_node.raw_parent_id is not None:
|
||||
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
|
||||
cleaned_node.raw_parent_id = sanitize_string(cleaned_node.raw_parent_id)
|
||||
if cleaned_node.link is not None:
|
||||
cleaned_node.link = _sanitize_string(cleaned_node.link)
|
||||
cleaned_node.link = sanitize_string(cleaned_node.link)
|
||||
|
||||
if cleaned_node.external_access is not None:
|
||||
cleaned_node.external_access = _sanitize_external_access(
|
||||
@@ -24,6 +24,9 @@ class OnyxVersion:
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def unset_ee(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
# zeep
|
||||
authlib==1.6.6
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
@@ -109,9 +109,7 @@ brotli==1.2.0
|
||||
bytecode==0.17.0
|
||||
# via ddtrace
|
||||
cachetools==6.2.2
|
||||
# via
|
||||
# google-auth
|
||||
# py-key-value-aio
|
||||
# via py-key-value-aio
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
@@ -190,6 +188,7 @@ courlan==1.3.2
|
||||
cryptography==46.0.5
|
||||
# via
|
||||
# authlib
|
||||
# google-auth
|
||||
# msal
|
||||
# msoffcrypto-tool
|
||||
# pdfminer-six
|
||||
@@ -230,9 +229,7 @@ distro==1.9.0
|
||||
dnspython==2.8.0
|
||||
# via email-validator
|
||||
docstring-parser==0.17.0
|
||||
# via
|
||||
# cyclopts
|
||||
# google-cloud-aiplatform
|
||||
# via cyclopts
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
@@ -297,26 +294,15 @@ gitdb==4.0.12
|
||||
gitpython==3.1.45
|
||||
# via braintrust
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# via google-api-python-client
|
||||
google-api-python-client==2.86.0
|
||||
# via onyx
|
||||
google-auth==2.43.0
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
# google-auth-oauthlib
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
@@ -325,51 +311,16 @@ google-auth-httplib2==0.1.0
|
||||
# onyx
|
||||
google-auth-oauthlib==1.0.0
|
||||
# via onyx
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# via onyx
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
greenlet==3.2.4
|
||||
# via
|
||||
# playwright
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -670,8 +621,6 @@ packaging==24.2
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# kombu
|
||||
@@ -721,19 +670,12 @@ propcache==0.4.1
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# via google-api-core
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
@@ -771,7 +713,6 @@ pydantic==2.11.7
|
||||
# exa-py
|
||||
# fastapi
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# langchain-core
|
||||
# langfuse
|
||||
@@ -835,7 +776,6 @@ python-dateutil==2.8.2
|
||||
# botocore
|
||||
# celery
|
||||
# dateparser
|
||||
# google-cloud-bigquery
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
@@ -927,8 +867,6 @@ requests==2.32.5
|
||||
# dropbox
|
||||
# exa-py
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# huggingface-hub
|
||||
@@ -1002,9 +940,7 @@ sendgrid==6.12.5
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
simple-salesforce==1.12.6
|
||||
@@ -1118,9 +1054,7 @@ typing-extensions==4.15.0
|
||||
# exa-py
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
|
||||
@@ -59,8 +59,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
celery-types==0.19.0
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
@@ -100,7 +98,9 @@ comm==0.2.3
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
debugpy==1.8.17
|
||||
@@ -115,8 +115,6 @@ distlib==0.4.0
|
||||
# via virtualenv
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
execnet==2.1.2
|
||||
@@ -145,65 +143,14 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# via onyx
|
||||
greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -311,13 +258,12 @@ numpy==2.4.1
|
||||
# contourpy
|
||||
# matplotlib
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.2
|
||||
onyx-devtools==0.6.3
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -330,8 +276,6 @@ openapi-generator-cli==7.17.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# hatchling
|
||||
# huggingface-hub
|
||||
# ipykernel
|
||||
@@ -374,20 +318,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via ipykernel
|
||||
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
@@ -409,7 +339,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -450,7 +379,6 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# jupyter-client
|
||||
# kubernetes
|
||||
# matplotlib
|
||||
@@ -485,9 +413,6 @@ reorder-python-imports-black==3.14.0
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -510,8 +435,6 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -602,9 +525,7 @@ typing-extensions==4.15.0
|
||||
# celery-types
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mcp
|
||||
|
||||
@@ -53,8 +53,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# httpcore
|
||||
@@ -79,15 +77,15 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.133.1
|
||||
@@ -104,63 +102,12 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
# via onyx
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -221,9 +168,7 @@ multidict==6.7.0
|
||||
# aiohttp
|
||||
# yarl
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# shapely
|
||||
# voyageai
|
||||
# via voyageai
|
||||
oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -233,10 +178,7 @@ openai==2.14.0
|
||||
# litellm
|
||||
# onyx
|
||||
packaging==24.2
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# via huggingface-hub
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
@@ -251,20 +193,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
@@ -280,7 +208,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -297,7 +224,6 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
# posthog
|
||||
python-dotenv==1.1.1
|
||||
@@ -321,9 +247,6 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -345,8 +268,6 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -385,9 +306,7 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
@@ -57,8 +57,6 @@ botocore==1.39.11
|
||||
# s3transfer
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
cachetools==6.2.2
|
||||
# via google-auth
|
||||
celery==5.5.1
|
||||
# via sentry-sdk
|
||||
certifi==2025.11.12
|
||||
@@ -95,15 +93,15 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.5
|
||||
# via pyjwt
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
@@ -129,63 +127,12 @@ fsspec==2025.10.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.43.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.121.0
|
||||
# via onyx
|
||||
google-cloud-bigquery==3.38.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.0
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.15.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==2.19.0
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.7.1
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.7.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
# via onyx
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -263,7 +210,6 @@ numpy==2.4.1
|
||||
# onyx
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# shapely
|
||||
# transformers
|
||||
# voyageai
|
||||
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
@@ -316,8 +262,6 @@ openai==2.14.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# kombu
|
||||
# transformers
|
||||
@@ -337,20 +281,6 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
@@ -368,7 +298,6 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -386,7 +315,6 @@ python-dateutil==2.8.2
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# celery
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
@@ -413,9 +341,6 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -452,8 +377,6 @@ sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
setuptools==80.9.0 ; python_full_version >= '3.12'
|
||||
# via torch
|
||||
shapely==2.0.6
|
||||
# via google-cloud-aiplatform
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -510,9 +433,7 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
171
backend/scripts/debugging/opensearch/opensearch_debug.py
Normal file
171
backend/scripts/debugging/opensearch/opensearch_debug.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""A utility to interact with OpenSearch.
|
||||
|
||||
Usage:
|
||||
python3 opensearch_debug.py --help
|
||||
python3 opensearch_debug.py list
|
||||
python3 opensearch_debug.py delete <index_name>
|
||||
|
||||
Environment Variables:
|
||||
OPENSEARCH_HOST: OpenSearch host
|
||||
OPENSEARCH_REST_API_PORT: OpenSearch port
|
||||
OPENSEARCH_ADMIN_USERNAME: Admin username
|
||||
OPENSEARCH_ADMIN_PASSWORD: Admin password
|
||||
|
||||
Dependencies:
|
||||
backend/shared_configs/configs.py
|
||||
backend/onyx/document_index/opensearch/client.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def list_indices(client: OpenSearchClient) -> None:
|
||||
indices = client.list_indices_with_info()
|
||||
print(f"Found {len(indices)} indices.")
|
||||
print("-" * 80)
|
||||
for index in sorted(indices, key=lambda x: x.name):
|
||||
print(f"Index: {index.name}")
|
||||
print(f"Health: {index.health}")
|
||||
print(f"Status: {index.status}")
|
||||
print(f"Num Primary Shards: {index.num_primary_shards}")
|
||||
print(f"Num Replica Shards: {index.num_replica_shards}")
|
||||
print(f"Docs Count: {index.docs_count}")
|
||||
print(f"Docs Deleted: {index.docs_deleted}")
|
||||
print(f"Created At: {index.created_at}")
|
||||
print(f"Total Size: {index.total_size}")
|
||||
print(f"Primary Shards Size: {index.primary_shards_size}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def delete_index(client: OpenSearchIndexClient) -> None:
|
||||
if not client.index_exists():
|
||||
print(f"Index '{client._index_name}' does not exist.")
|
||||
return
|
||||
|
||||
confirm = input(f"Delete index '{client._index_name}'? (yes/no): ")
|
||||
if confirm.lower() != "yes":
|
||||
print("Aborted.")
|
||||
return
|
||||
|
||||
if client.delete_index():
|
||||
print(f"Deleted index '{client._index_name}'.")
|
||||
else:
|
||||
print(f"Failed to delete index '{client._index_name}' for an unknown reason.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def add_standard_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
help="OpenSearch host. If not provided, will fall back to OPENSEARCH_HOST, then prompt "
|
||||
"for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_HOST", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
help="OpenSearch port. If not provided, will fall back to OPENSEARCH_REST_API_PORT, "
|
||||
"then prompt for input.",
|
||||
type=int,
|
||||
default=int(os.environ.get("OPENSEARCH_REST_API_PORT", 0)),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--username",
|
||||
help="OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, "
|
||||
"then prompt for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_USERNAME", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password",
|
||||
help="OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, "
|
||||
"then prompt for input.",
|
||||
type=str,
|
||||
default=os.environ.get("OPENSEARCH_ADMIN_PASSWORD", ""),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-ssl", help="Disable SSL.", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify-certs",
|
||||
help="Disable certificate verification (for self-signed certs).",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-aws-managed-opensearch",
|
||||
help="Whether to use AWS-managed OpenSearch. If not provided, will fall back to checking "
|
||||
"USING_AWS_MANAGED_OPENSEARCH=='true', then default to False.",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower()
|
||||
== "true",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A utility to interact with OpenSearch."
|
||||
)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest="command", help="Command to execute.", required=True
|
||||
)
|
||||
|
||||
list_parser = subparsers.add_parser("list", help="List all indices with info.")
|
||||
add_standard_arguments(list_parser)
|
||||
|
||||
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
|
||||
delete_parser.add_argument("index", help="Index name.", type=str)
|
||||
add_standard_arguments(delete_parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not (host := args.host or input("Enter the OpenSearch host: ")):
|
||||
print("Error: OpenSearch host is required.")
|
||||
sys.exit(1)
|
||||
if not (port := args.port or int(input("Enter the OpenSearch port: "))):
|
||||
print("Error: OpenSearch port is required.")
|
||||
sys.exit(1)
|
||||
if not (username := args.username or input("Enter the OpenSearch username: ")):
|
||||
print("Error: OpenSearch username is required.")
|
||||
sys.exit(1)
|
||||
if not (password := args.password or input("Enter the OpenSearch password: ")):
|
||||
print("Error: OpenSearch password is required.")
|
||||
sys.exit(1)
|
||||
print("Using AWS-managed OpenSearch: ", args.use_aws_managed_opensearch)
|
||||
print(f"MULTI_TENANT: {MULTI_TENANT}")
|
||||
|
||||
with (
|
||||
OpenSearchIndexClient(
|
||||
index_name=args.index,
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
use_ssl=not args.no_ssl,
|
||||
verify_certs=not args.no_verify_certs,
|
||||
)
|
||||
if args.command == "delete"
|
||||
else OpenSearchClient(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=(username, password),
|
||||
use_ssl=not args.no_ssl,
|
||||
verify_certs=not args.no_verify_certs,
|
||||
)
|
||||
) as client:
|
||||
if not client.ping():
|
||||
print("Error: Could not connect to OpenSearch.")
|
||||
sys.exit(1)
|
||||
|
||||
if args.command == "list":
|
||||
list_indices(client)
|
||||
elif args.command == "delete":
|
||||
delete_index(client)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
backend/scripts/reencrypt_secrets.py
Normal file
107
backend/scripts/reencrypt_secrets.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Decrypts all encrypted columns using the old key (or raw decode if the old key
|
||||
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Usage (docker):
|
||||
docker exec -it onyx-api_server-1 \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Usage (kubernetes):
|
||||
kubectl exec -it <pod> -- \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Omit --old-key (or pass "") if secrets were not previously encrypted.
|
||||
|
||||
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
|
||||
or --all-tenants to iterate every tenant.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
|
||||
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
|
||||
|
||||
|
||||
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
|
||||
print(f"Re-encrypting secrets for tenant: {tenant_id}")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
|
||||
|
||||
if results:
|
||||
for col, count in results.items():
|
||||
print(
|
||||
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
|
||||
)
|
||||
else:
|
||||
print("No rows needed re-encryption.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Re-encrypt secrets under the current encryption key."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old-key",
|
||||
default=None,
|
||||
help="Previous encryption key. Omit or pass empty string if not applicable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be re-encrypted without making changes.",
|
||||
)
|
||||
|
||||
tenant_group = parser.add_mutually_exclusive_group()
|
||||
tenant_group.add_argument(
|
||||
"--tenant-id",
|
||||
default=None,
|
||||
help="Target a specific tenant schema.",
|
||||
)
|
||||
tenant_group.add_argument(
|
||||
"--all-tenants",
|
||||
action="store_true",
|
||||
help="Iterate all tenants.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
old_key = args.old_key if args.old_key else None
|
||||
|
||||
global_version.set_ee()
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
if args.dry_run:
|
||||
print("DRY RUN — no changes will be made")
|
||||
|
||||
if args.all_tenants:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
print(f"Found {len(tenant_ids)} tenant(s)")
|
||||
failed_tenants: list[str] = []
|
||||
for tid in tenant_ids:
|
||||
try:
|
||||
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
|
||||
except Exception as e:
|
||||
print(f" ERROR for tenant {tid}: {e}")
|
||||
failed_tenants.append(tid)
|
||||
if failed_tenants:
|
||||
print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
71
backend/tests/README.md
Normal file
71
backend/tests/README.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Backend Tests
|
||||
|
||||
## Test Types
|
||||
|
||||
There are four test categories, ordered by increasing scope:
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
|
||||
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
|
||||
logic (e.g. citation processing, encryption).
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
|
||||
|
||||
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
|
||||
application containers are not. Tests call functions directly and can mock selectively.
|
||||
|
||||
Use when you need a real database or real API calls but want control over setup.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
|
||||
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright / E2E Tests (`web/tests/e2e/`)
|
||||
|
||||
Full stack including web server. Use for frontend-backend coordination.
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
## Shared Fixtures
|
||||
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Enables EE mode for a test, with proper teardown and cache clearing.
|
||||
|
||||
```python
|
||||
# Whole file (in a test module, NOT in conftest.py)
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Whole directory — add an autouse wrapper to the directory's conftest.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
|
||||
# Single test
|
||||
def test_something(enable_ee: None) -> None: ...
|
||||
```
|
||||
|
||||
**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that
|
||||
directory — it only affects tests defined in the conftest itself (which is none).
|
||||
Use the autouse fixture wrapper pattern shown above instead.
|
||||
|
||||
Do NOT inline `global_version.set_ee()` — always use the fixture.
|
||||
24
backend/tests/conftest.py
Normal file
24
backend/tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Root conftest — shared fixtures available to all test directories."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def enable_ee() -> Generator[None, None, None]:
|
||||
"""Temporarily enable EE mode for a single test.
|
||||
|
||||
Restores the previous EE state and clears the versioned-implementation
|
||||
cache on teardown so state doesn't leak between tests.
|
||||
"""
|
||||
was_ee = global_version.is_ee_version()
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
if not was_ee:
|
||||
global_version.unset_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
def test_confluence_connector_permissions(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
confluence_connector: ConfluenceConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
|
||||
def test_confluence_connector_restriction_handling(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
mock_db_provider_class: MagicMock,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Test space key
|
||||
test_space_key = "DailyPermS"
|
||||
|
||||
@@ -4,8 +4,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
@@ -14,14 +12,3 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
@@ -98,7 +98,7 @@ def _build_connector(
|
||||
|
||||
def test_gdrive_perm_sync_with_real_data(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
|
||||
@@ -19,16 +17,7 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.teams.models import TeamsThread
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
@@ -168,18 +166,9 @@ def test_slim_docs_retrieval_from_teams_connector(
|
||||
_assert_is_valid_external_access(external_access=slim_doc.external_access)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for perm sync tests to work since
|
||||
perm syncing is an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
yield
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_perm_sync(
|
||||
teams_connector: TeamsConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.
|
||||
|
||||
|
||||
@@ -145,6 +145,10 @@ class TestDocprocessingPriorityInDocumentExtraction:
|
||||
@patch("onyx.background.indexing.run_docfetching.get_document_batch_storage")
|
||||
@patch("onyx.background.indexing.run_docfetching.MemoryTracer")
|
||||
@patch("onyx.background.indexing.run_docfetching._get_connector_runner")
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.strip_null_characters",
|
||||
side_effect=lambda batch: batch,
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.get_recent_completed_attempts_for_cc_pair"
|
||||
)
|
||||
@@ -169,6 +173,7 @@ class TestDocprocessingPriorityInDocumentExtraction:
|
||||
mock_save_checkpoint: MagicMock, # noqa: ARG002
|
||||
mock_get_last_successful_attempt_poll_range_end: MagicMock,
|
||||
mock_get_recent_completed_attempts: MagicMock,
|
||||
mock_strip_null_characters: MagicMock, # noqa: ARG002
|
||||
mock_get_connector_runner: MagicMock,
|
||||
mock_memory_tracer_class: MagicMock,
|
||||
mock_get_batch_storage: MagicMock,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -14,13 +15,14 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
# In order to get these tests to run, use the credentials from Bitwarden.
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
class DocExternalAccessSet(BaseModel):
|
||||
"""A version of DocExternalAccess that uses sets for comparison."""
|
||||
@@ -52,9 +54,6 @@ def test_jira_doc_sync(
|
||||
This test uses the AS project which has applicationRole permission,
|
||||
meaning all documents should be marked as public.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use AS project specifically for this test
|
||||
connector_config = {
|
||||
@@ -150,9 +149,6 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
This test uses a project that has specific user permissions to verify
|
||||
that specific users are correctly extracted.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use SUP project which has specific user permissions
|
||||
connector_config = {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
@@ -18,6 +19,8 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Expected groups from the danswerai.atlassian.net Jira instance
|
||||
# Note: These groups are shared with Confluence since they're both Atlassian products
|
||||
# App accounts (bots, integrations) are filtered out
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
"""Tests for rotate_encryption_key against real Postgres.
|
||||
|
||||
Uses real ORM models (Credential, InternetSearchProvider) and the actual
|
||||
Postgres database. Discovery is mocked in rotation tests to scope mutations
|
||||
to only the test rows — the real _discover_encrypted_columns walk is tested
|
||||
separately in TestDiscoverEncryptedColumns.
|
||||
|
||||
Requires a running Postgres instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
|
||||
|
||||
OLD_KEY = "o" * 16
|
||||
NEW_KEY = "n" * 16
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee() -> Generator[None, None, None]:
|
||||
prev = global_version._is_ee
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
global_version._is_ee = prev
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
|
||||
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
|
||||
col = Credential.__table__.c.credential_json
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
Credential.__table__.c.id == credential_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
|
||||
"""Read raw bytes from InternetSearchProvider.api_key."""
|
||||
col = InternetSearchProvider.__table__.c.api_key
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
InternetSearchProvider.__table__.c.id == isp_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
class TestDiscoverEncryptedColumns:
|
||||
"""Verify _discover_encrypted_columns finds real production models."""
|
||||
|
||||
def test_discovers_credential_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("credential", "credential_json", True) in found
|
||||
|
||||
def test_discovers_internet_search_provider_api_key(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("internet_search_provider", "api_key", False) in found
|
||||
|
||||
def test_all_encrypted_string_columns_are_not_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedString):
|
||||
assert not is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
def test_all_encrypted_json_columns_are_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
assert is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
|
||||
class TestRotateCredential:
|
||||
"""Test rotation against the real Credential table (EncryptedJson).
|
||||
|
||||
Discovery is scoped to only the Credential model to avoid mutating
|
||||
other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[(Credential, "credential_json", ["id"], True)],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def credential_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert a Credential row with raw encrypted bytes, clean up after."""
|
||||
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
|
||||
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO credential "
|
||||
"(source, credential_json, admin_public, curator_public) "
|
||||
"VALUES (:source, :cred_json, true, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
|
||||
)
|
||||
cred_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield cred_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_credential_json(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
assert decrypted["endpoint"] == "https://example.com"
|
||||
|
||||
def test_skips_already_rotated(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
|
||||
def test_dry_run_does_not_modify(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
original = _raw_credential_bytes(db_session, credential_id)
|
||||
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw_after = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw_after == original
|
||||
|
||||
|
||||
class TestRotateInternetSearchProvider:
|
||||
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
|
||||
|
||||
Discovery is scoped to only the InternetSearchProvider model to avoid
|
||||
mutating other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[
|
||||
(InternetSearchProvider, "api_key", ["id"], False),
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def isp_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
|
||||
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-rotation-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": encrypted,
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield isp_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
|
||||
|
||||
def test_rotates_from_unencrypted(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> None:
|
||||
"""Test rotating data that was stored without any encryption key."""
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-raw-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": b"raw-api-key",
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=None)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -698,6 +698,99 @@ class TestAutoModeMissingFlows:
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_transition_to_auto_mode_preserves_default(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the default provider transitions from manual to auto mode,
|
||||
the global default should be preserved (set to the recommended model).
|
||||
|
||||
Steps:
|
||||
1. Create a manual-mode provider with models, set it as global default.
|
||||
2. Transition to auto mode (model_configurations=[] triggers cascade
|
||||
delete of old ModelConfigurations and their LLMModelFlow rows).
|
||||
3. Verify the provider is still the global default, now using the
|
||||
recommended default model from the GitHub config.
|
||||
"""
|
||||
initial_models = [
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True),
|
||||
]
|
||||
|
||||
auto_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create manual-mode provider and set as default
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False,
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
assert default_before.llm_provider_id == provider.id
|
||||
|
||||
# Step 2: Transition to auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=auto_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Default should be preserved on this provider
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None, (
|
||||
"Default model should not be None after transitioning to auto mode — "
|
||||
"the provider was the default before and should remain so"
|
||||
)
|
||||
assert (
|
||||
default_after.llm_provider_id == provider.id
|
||||
), "Default should still belong to the same provider after transition"
|
||||
assert default_after.name == "gpt-4o-mini", (
|
||||
f"Default should be updated to the recommended model 'gpt-4o-mini', "
|
||||
f"got '{default_after.name}'"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
@@ -1042,14 +1135,19 @@ class TestAutoModeTransitionsAndResync:
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
# The old default (gpt-4o) is now hidden. sync_auto_mode_models
|
||||
# should update the global default to the new recommended default
|
||||
# (gpt-4o-mini) so that it is not silently lost.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
assert default_after is not None, (
|
||||
"Default model should not be None — sync should set the new "
|
||||
"recommended default when the old one is hidden"
|
||||
)
|
||||
assert default_after.name == "gpt-4o-mini", (
|
||||
f"Default should be updated to the new recommended model "
|
||||
f"'gpt-4o-mini', but got '{default_after.name}'"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.db.oauth_config import get_tools_by_oauth_config
|
||||
from onyx.db.oauth_config import get_user_oauth_token
|
||||
from onyx.db.oauth_config import update_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.db.tools import delete_tool__no_commit
|
||||
from onyx.db.tools import update_tool
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -312,6 +314,85 @@ class TestOAuthConfigCRUD:
|
||||
# Tool should still exist but oauth_config_id should be NULL
|
||||
assert tool.oauth_config_id is None
|
||||
|
||||
def test_update_tool_cleans_up_orphaned_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that changing a tool's oauth_config_id deletes the old config if no other tool uses it."""
|
||||
old_config = _create_test_oauth_config(db_session)
|
||||
new_config = _create_test_oauth_config(db_session)
|
||||
tool = _create_test_tool_with_oauth(db_session, old_config)
|
||||
old_config_id = old_config.id
|
||||
|
||||
update_tool(
|
||||
tool_id=tool.id,
|
||||
name=None,
|
||||
description=None,
|
||||
openapi_schema=None,
|
||||
custom_headers=None,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=None,
|
||||
oauth_config_id=new_config.id,
|
||||
)
|
||||
|
||||
assert tool.oauth_config_id == new_config.id
|
||||
assert get_oauth_config(old_config_id, db_session) is None
|
||||
|
||||
def test_delete_tool_cleans_up_orphaned_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that deleting the last tool referencing an OAuthConfig also deletes the config."""
|
||||
config = _create_test_oauth_config(db_session)
|
||||
tool = _create_test_tool_with_oauth(db_session, config)
|
||||
config_id = config.id
|
||||
|
||||
delete_tool__no_commit(tool.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert get_oauth_config(config_id, db_session) is None
|
||||
|
||||
def test_update_tool_preserves_shared_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that updating one tool's oauth_config_id preserves the config when another tool still uses it."""
|
||||
shared_config = _create_test_oauth_config(db_session)
|
||||
new_config = _create_test_oauth_config(db_session)
|
||||
tool_a = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
tool_b = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
shared_config_id = shared_config.id
|
||||
|
||||
# Move tool_a to a new config; tool_b still references shared_config
|
||||
update_tool(
|
||||
tool_id=tool_a.id,
|
||||
name=None,
|
||||
description=None,
|
||||
openapi_schema=None,
|
||||
custom_headers=None,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=None,
|
||||
oauth_config_id=new_config.id,
|
||||
)
|
||||
|
||||
assert tool_a.oauth_config_id == new_config.id
|
||||
assert tool_b.oauth_config_id == shared_config_id
|
||||
assert get_oauth_config(shared_config_id, db_session) is not None
|
||||
|
||||
def test_delete_tool_preserves_shared_oauth_config(
|
||||
self, db_session: Session
|
||||
) -> None:
|
||||
"""Test that deleting one tool preserves the config when another tool still uses it."""
|
||||
shared_config = _create_test_oauth_config(db_session)
|
||||
tool_a = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
tool_b = _create_test_tool_with_oauth(db_session, shared_config)
|
||||
shared_config_id = shared_config.id
|
||||
|
||||
delete_tool__no_commit(tool_a.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert tool_b.oauth_config_id == shared_config_id
|
||||
assert get_oauth_config(shared_config_id, db_session) is not None
|
||||
|
||||
|
||||
class TestOAuthUserTokenCRUD:
|
||||
"""Tests for OAuth user token CRUD operations"""
|
||||
|
||||
8
backend/tests/unit/ee/conftest.py
Normal file
8
backend/tests/unit/ee/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Auto-enable EE mode for all tests under tests/unit/ee/."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
165
backend/tests/unit/ee/onyx/utils/test_encryption.py
Normal file
165
backend/tests/unit/ee/onyx/utils/test_encryption.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
|
||||
|
||||
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
|
||||
to the EE implementations, so no patching of the MIT layer is needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from ee.onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from ee.onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
|
||||
# Keys must be exactly 16, 24, or 32 bytes for AES
|
||||
KEY_16 = "a" * 16
|
||||
KEY_16_ALT = "b" * 16
|
||||
KEY_24 = "d" * 24
|
||||
KEY_32 = "c" * 32
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
def test_roundtrip_with_env_key(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("hello world")
|
||||
assert encrypted != b"hello world"
|
||||
assert _decrypt_bytes(encrypted) == "hello world"
|
||||
|
||||
def test_roundtrip_with_explicit_key(self) -> None:
|
||||
encrypted = _encrypt_string("secret data", key=KEY_32)
|
||||
assert encrypted != b"secret data"
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
|
||||
|
||||
def test_roundtrip_no_key(self) -> None:
|
||||
"""Without any key, data is raw-encoded (no encryption)."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
encrypted = _encrypt_string("plain text")
|
||||
assert encrypted == b"plain text"
|
||||
assert _decrypt_bytes(encrypted) == "plain text"
|
||||
|
||||
def test_explicit_key_overrides_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("data", key=KEY_16_ALT)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
|
||||
|
||||
def test_different_encryptions_produce_different_bytes(self) -> None:
|
||||
"""Each encryption uses a random IV, so results differ."""
|
||||
a = _encrypt_string("same", key=KEY_16)
|
||||
b = _encrypt_string("same", key=KEY_16)
|
||||
assert a != b
|
||||
|
||||
def test_roundtrip_empty_string(self) -> None:
|
||||
encrypted = _encrypt_string("", key=KEY_16)
|
||||
assert encrypted != b""
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
|
||||
|
||||
def test_roundtrip_unicode(self) -> None:
|
||||
text = "日本語テスト 🔐 émojis"
|
||||
encrypted = _encrypt_string(text, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == text
|
||||
|
||||
|
||||
class TestDecryptFallbackBehavior:
|
||||
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
|
||||
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
|
||||
raw = "readable text".encode()
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
assert _decrypt_bytes(raw) == "readable text"
|
||||
|
||||
def test_explicit_wrong_key_raises(self) -> None:
|
||||
"""Explicit key path: AES fails → raises, no fallback."""
|
||||
encrypted = _encrypt_string("secret", key=KEY_16)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16_ALT)
|
||||
|
||||
def test_explicit_none_key_with_no_env(self) -> None:
|
||||
"""key=None with empty env → raw decode."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
assert _decrypt_bytes(b"hello", key=None) == "hello"
|
||||
|
||||
def test_explicit_empty_string_key(self) -> None:
|
||||
"""key='' means no encryption."""
|
||||
encrypted = _encrypt_string("test", key="")
|
||||
assert encrypted == b"test"
|
||||
assert _decrypt_bytes(encrypted, key="") == "test"
|
||||
|
||||
|
||||
class TestKeyValidation:
|
||||
def test_key_too_short_raises(self) -> None:
|
||||
with pytest.raises(RuntimeError, match="too short"):
|
||||
_encrypt_string("data", key="short")
|
||||
|
||||
def test_16_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
|
||||
|
||||
def test_24_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_24)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
|
||||
|
||||
def test_32_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_32)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
|
||||
|
||||
def test_long_key_truncated_to_32(self) -> None:
|
||||
"""Keys longer than 32 bytes are truncated to 32."""
|
||||
long_key = "e" * 64
|
||||
encrypted = _encrypt_string("data", key=long_key)
|
||||
assert _decrypt_bytes(encrypted, key=long_key) == "data"
|
||||
|
||||
def test_20_byte_key_trimmed_to_16(self) -> None:
|
||||
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
|
||||
key_20 = "f" * 20
|
||||
encrypted = _encrypt_string("data", key=key_20)
|
||||
assert _decrypt_bytes(encrypted, key=key_20) == "data"
|
||||
|
||||
# Verify it was trimmed to 16 by checking that the first 16 bytes
|
||||
# of the key can also decrypt it
|
||||
key_16_same_prefix = "f" * 16
|
||||
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
|
||||
|
||||
def test_25_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_25 = "g" * 25
|
||||
encrypted = _encrypt_string("data", key=key_25)
|
||||
assert _decrypt_bytes(encrypted, key=key_25) == "data"
|
||||
|
||||
key_24_same_prefix = "g" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
def test_30_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_30 = "h" * 30
|
||||
encrypted = _encrypt_string("data", key=key_30)
|
||||
assert _decrypt_bytes(encrypted, key=key_30) == "data"
|
||||
|
||||
key_24_same_prefix = "h" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
|
||||
class TestWrapperFunctions:
|
||||
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
|
||||
|
||||
With EE mode enabled, the wrappers resolve to EE implementations automatically.
|
||||
"""
|
||||
|
||||
def test_wrapper_passes_key(self) -> None:
|
||||
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
|
||||
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
|
||||
|
||||
def test_wrapper_no_key_uses_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
|
||||
encrypted = encrypt_string_to_bytes("payload")
|
||||
assert decrypt_bytes_to_string(encrypted) == "payload"
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
|
||||
from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
@@ -14,22 +13,11 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
class _StubConfig:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.model_provider = model_provider
|
||||
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.config = _StubConfig(model_provider=model_provider)
|
||||
|
||||
|
||||
def create_message(
|
||||
content: str, message_type: MessageType, token_count: int | None = None
|
||||
) -> ChatMessageSimple:
|
||||
@@ -946,37 +934,6 @@ class TestForgottenFileMetadata:
|
||||
assert "moby_dick.txt" in forgotten.message
|
||||
|
||||
|
||||
class TestBedrockToolConfigGuard:
|
||||
def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is True
|
||||
|
||||
def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_message("Answer", MessageType.ASSISTANT, 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.OPENAI)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
|
||||
class TestFallbackToolExtraction:
|
||||
def _tool_defs(self) -> list[dict]:
|
||||
return [
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.chat.llm_step import _extract_tool_call_kickoffs
|
||||
from onyx.chat.llm_step import _increment_turns
|
||||
from onyx.chat.llm_step import _parse_tool_args_to_dict
|
||||
from onyx.chat.llm_step import _resolve_tool_arguments
|
||||
from onyx.chat.llm_step import _sanitize_llm_output
|
||||
from onyx.chat.llm_step import _XmlToolCallContentFilter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import translate_history_to_llm_format
|
||||
@@ -21,48 +20,49 @@ from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
class TestSanitizeLlmOutput:
|
||||
"""Tests for the _sanitize_llm_output function."""
|
||||
"""Tests for the sanitize_string function."""
|
||||
|
||||
def test_removes_null_bytes(self) -> None:
|
||||
"""Test that NULL bytes are removed from strings."""
|
||||
assert _sanitize_llm_output("hello\x00world") == "helloworld"
|
||||
assert _sanitize_llm_output("\x00start") == "start"
|
||||
assert _sanitize_llm_output("end\x00") == "end"
|
||||
assert _sanitize_llm_output("\x00\x00\x00") == ""
|
||||
assert sanitize_string("hello\x00world") == "helloworld"
|
||||
assert sanitize_string("\x00start") == "start"
|
||||
assert sanitize_string("end\x00") == "end"
|
||||
assert sanitize_string("\x00\x00\x00") == ""
|
||||
|
||||
def test_removes_surrogates(self) -> None:
|
||||
"""Test that UTF-16 surrogates are removed from strings."""
|
||||
# Low surrogate
|
||||
assert _sanitize_llm_output("hello\ud800world") == "helloworld"
|
||||
assert sanitize_string("hello\ud800world") == "helloworld"
|
||||
# High surrogate
|
||||
assert _sanitize_llm_output("hello\udfffworld") == "helloworld"
|
||||
assert sanitize_string("hello\udfffworld") == "helloworld"
|
||||
# Middle of surrogate range
|
||||
assert _sanitize_llm_output("test\uda00value") == "testvalue"
|
||||
assert sanitize_string("test\uda00value") == "testvalue"
|
||||
|
||||
def test_removes_mixed_bad_characters(self) -> None:
|
||||
"""Test removal of both NULL bytes and surrogates together."""
|
||||
assert _sanitize_llm_output("a\x00b\ud800c\udfffd") == "abcd"
|
||||
assert sanitize_string("a\x00b\ud800c\udfffd") == "abcd"
|
||||
|
||||
def test_preserves_valid_unicode(self) -> None:
|
||||
"""Test that valid Unicode characters are preserved."""
|
||||
# Emojis
|
||||
assert _sanitize_llm_output("hello 👋 world") == "hello 👋 world"
|
||||
assert sanitize_string("hello 👋 world") == "hello 👋 world"
|
||||
# Chinese characters
|
||||
assert _sanitize_llm_output("你好世界") == "你好世界"
|
||||
assert sanitize_string("你好世界") == "你好世界"
|
||||
# Mixed scripts
|
||||
assert _sanitize_llm_output("Hello мир 世界") == "Hello мир 世界"
|
||||
assert sanitize_string("Hello мир 世界") == "Hello мир 世界"
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
"""Test that empty strings are handled correctly."""
|
||||
assert _sanitize_llm_output("") == ""
|
||||
assert sanitize_string("") == ""
|
||||
|
||||
def test_normal_ascii(self) -> None:
|
||||
"""Test that normal ASCII strings pass through unchanged."""
|
||||
assert _sanitize_llm_output("hello world") == "hello world"
|
||||
assert _sanitize_llm_output('{"key": "value"}') == '{"key": "value"}'
|
||||
assert sanitize_string("hello world") == "hello world"
|
||||
assert sanitize_string('{"key": "value"}') == '{"key": "value"}'
|
||||
|
||||
|
||||
class TestParseToolArgsToDict:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Tests for _extract_referenced_file_descriptors in save_chat.py.
|
||||
"""Tests for save_chat.py.
|
||||
|
||||
Verifies that only code interpreter generated files actually referenced
|
||||
in the assistant's message text are extracted as FileDescriptors for
|
||||
cross-turn persistence.
|
||||
Covers _extract_referenced_file_descriptors and sanitization in save_chat_turn.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from onyx.chat import save_chat
|
||||
from onyx.chat.save_chat import _extract_referenced_file_descriptors
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
@@ -29,6 +32,9 @@ def _make_tool_call_info(
|
||||
)
|
||||
|
||||
|
||||
# ---- _extract_referenced_file_descriptors tests ----
|
||||
|
||||
|
||||
def test_returns_empty_when_no_generated_files() -> None:
|
||||
tool_call = _make_tool_call_info(generated_files=None)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "some message")
|
||||
@@ -176,3 +182,34 @@ def test_skips_tool_calls_without_generated_files() -> None:
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
|
||||
|
||||
# ---- save_chat_turn sanitization test ----
|
||||
|
||||
|
||||
def test_save_chat_turn_sanitizes_message_and_reasoning(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||
monkeypatch.setattr(save_chat, "get_tokenizer", lambda *_a, **_kw: mock_tokenizer)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 1
|
||||
mock_msg.chat_session_id = "test"
|
||||
mock_msg.files = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
save_chat.save_chat_turn(
|
||||
message_text="hello\x00world\ud800",
|
||||
reasoning_tokens="think\x00ing\udfff",
|
||||
tool_calls=[],
|
||||
citation_to_doc={},
|
||||
all_search_docs={},
|
||||
db_session=mock_session,
|
||||
assistant_message=mock_msg,
|
||||
)
|
||||
|
||||
assert mock_msg.message == "helloworld"
|
||||
assert mock_msg.reasoning_tokens == "thinking"
|
||||
@@ -9,6 +9,8 @@ from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jira_cc_pair(
|
||||
|
||||
27
backend/tests/unit/onyx/db/test_tools.py
Normal file
27
backend/tests/unit/onyx/db/test_tools.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db import tools as tools_mod
|
||||
|
||||
|
||||
def test_create_tool_call_no_commit_sanitizes_fields() -> None:
|
||||
mock_session = MagicMock()
|
||||
|
||||
tool_call = tools_mod.create_tool_call_no_commit(
|
||||
chat_session_id=uuid4(),
|
||||
parent_chat_message_id=1,
|
||||
turn_number=0,
|
||||
tool_id=1,
|
||||
tool_call_id="tc-1",
|
||||
tool_call_arguments={"task\x00": "research\ud800 topic"},
|
||||
tool_call_response="report\x00 text\udfff here",
|
||||
tool_call_tokens=10,
|
||||
db_session=mock_session,
|
||||
reasoning_tokens="reason\x00ing\ud800",
|
||||
generated_images=[{"url": "img\x00.png\udfff"}],
|
||||
)
|
||||
|
||||
assert tool_call.tool_call_response == "report text here"
|
||||
assert tool_call.reasoning_tokens == "reasoning"
|
||||
assert tool_call.tool_call_arguments == {"task": "research topic"}
|
||||
assert tool_call.generated_images == [{"url": "img.png"}]
|
||||
@@ -1214,3 +1214,218 @@ def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
|
||||
|
||||
# The env lock context manager should never have been called
|
||||
mock_env_lock.assert_not_called()
|
||||
|
||||
|
||||
# ---- Tests for Bedrock tool content stripping ----
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_with_tool_role() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "I'll search for that."},
|
||||
{"role": "tool", "content": "search results", "tool_call_id": "tc_1"},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is True
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_with_tool_calls() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is True
|
||||
|
||||
|
||||
def test_messages_contain_tool_content_without_tools() -> None:
|
||||
from onyx.llm.multi_llm import _messages_contain_tool_content
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert _messages_contain_tool_content(messages) is False
|
||||
|
||||
|
||||
def test_strip_tool_content_converts_assistant_tool_calls_to_text() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Search for cats"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me search.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"query": "cats"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "Found 3 results about cats.",
|
||||
"tool_call_id": "tc_1",
|
||||
},
|
||||
{"role": "assistant", "content": "Here are the results."},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
assert len(result) == 4
|
||||
|
||||
# First message unchanged
|
||||
assert result[0] == {"role": "user", "content": "Search for cats"}
|
||||
|
||||
# Assistant with tool calls → plain text
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert "tool_calls" not in result[1]
|
||||
assert "Let me search." in result[1]["content"]
|
||||
assert "[Tool Call]" in result[1]["content"]
|
||||
assert "search" in result[1]["content"]
|
||||
assert "tc_1" in result[1]["content"]
|
||||
|
||||
# Tool response → user message
|
||||
assert result[2]["role"] == "user"
|
||||
assert "[Tool Result]" in result[2]["content"]
|
||||
assert "tc_1" in result[2]["content"]
|
||||
assert "Found 3 results about cats." in result[2]["content"]
|
||||
|
||||
# Final assistant message unchanged
|
||||
assert result[3] == {"role": "assistant", "content": "Here are the results."}
|
||||
|
||||
|
||||
def test_strip_tool_content_handles_assistant_with_no_text_content() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "[Tool Call]" in result[0]["content"]
|
||||
assert "tool_calls" not in result[0]
|
||||
|
||||
|
||||
def test_strip_tool_content_passes_through_non_tool_messages() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
assert result == messages
|
||||
|
||||
|
||||
def test_strip_tool_content_handles_list_content_blocks() -> None:
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Searching now."}],
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "result A"},
|
||||
{"type": "text", "text": "result B"},
|
||||
],
|
||||
"tool_call_id": "tc_1",
|
||||
},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Assistant: list content flattened + tool call appended
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "Searching now." in result[0]["content"]
|
||||
assert "[Tool Call]" in result[0]["content"]
|
||||
assert isinstance(result[0]["content"], str)
|
||||
|
||||
# Tool: list content flattened into user message
|
||||
assert result[1]["role"] == "user"
|
||||
assert "result A" in result[1]["content"]
|
||||
assert "result B" in result[1]["content"]
|
||||
assert isinstance(result[1]["content"], str)
|
||||
|
||||
|
||||
def test_strip_tool_content_merges_consecutive_tool_results() -> None:
|
||||
"""Bedrock requires strict user/assistant alternation. Multiple parallel
|
||||
tool results must be merged into a single user message."""
|
||||
from onyx.llm.multi_llm import _strip_tool_content_from_messages
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "weather and news?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search_weather", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "tc_2",
|
||||
"type": "function",
|
||||
"function": {"name": "search_news", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "sunny 72F", "tool_call_id": "tc_1"},
|
||||
{"role": "tool", "content": "headline news", "tool_call_id": "tc_2"},
|
||||
{"role": "assistant", "content": "Here are the results."},
|
||||
]
|
||||
|
||||
result = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# user, assistant (flattened), user (merged tool results), assistant
|
||||
assert len(result) == 4
|
||||
roles = [m["role"] for m in result]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
|
||||
# Both tool results merged into one user message
|
||||
merged = result[2]["content"]
|
||||
assert "tc_1" in merged
|
||||
assert "sunny 72F" in merged
|
||||
assert "tc_2" in merged
|
||||
assert "headline news" in merged
|
||||
|
||||
@@ -10,6 +10,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
@@ -317,3 +319,298 @@ class TestGetOpenRouterAvailableModels:
|
||||
# No DB operations should happen
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
class TestGetLMStudioAvailableModels:
|
||||
"""Tests for the LM Studio model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm_studio_response(self) -> dict:
|
||||
"""Mock response from LM Studio /api/v1/models endpoint."""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"key": "lmstudio-community/Meta-Llama-3-8B",
|
||||
"type": "llm",
|
||||
"display_name": "Meta Llama 3 8B",
|
||||
"max_context_length": 8192,
|
||||
"capabilities": {"vision": False},
|
||||
},
|
||||
{
|
||||
"key": "lmstudio-community/Qwen2.5-VL-7B",
|
||||
"type": "llm",
|
||||
"display_name": "Qwen 2.5 VL 7B",
|
||||
"max_context_length": 32768,
|
||||
"capabilities": {"vision": True},
|
||||
},
|
||||
{
|
||||
"key": "text-embedding-nomic-embed-text-v1.5",
|
||||
"type": "embedding",
|
||||
"display_name": "Nomic Embed Text v1.5",
|
||||
"max_context_length": 2048,
|
||||
"capabilities": {},
|
||||
},
|
||||
{
|
||||
"key": "lmstudio-community/DeepSeek-R1-8B",
|
||||
"type": "llm",
|
||||
"display_name": "DeepSeek R1 8B",
|
||||
"max_context_length": 65536,
|
||||
"capabilities": {"vision": False},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_lm_studio_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted LLM-only model list."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_lm_studio_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
results = get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Only LLM-type models should be returned (embedding filtered out)
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LMStudioFinalModelResponse) for r in results)
|
||||
names = [r.name for r in results]
|
||||
assert "text-embedding-nomic-embed-text-v1.5" not in names
|
||||
# Results should be alphabetically sorted by model name
|
||||
assert names == sorted(names, key=str.lower)
|
||||
|
||||
def test_infers_vision_support(self, mock_lm_studio_response: dict) -> None:
|
||||
"""Test that vision support is correctly read from capabilities."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_lm_studio_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
results = get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
qwen = next(r for r in results if "Qwen" in r.display_name)
|
||||
llama = next(r for r in results if "Llama" in r.display_name)
|
||||
|
||||
assert qwen.supports_image_input is True
|
||||
assert llama.supports_image_input is False
|
||||
|
||||
def test_infers_reasoning_from_model_name(self) -> None:
|
||||
"""Test that reasoning is inferred from model name when not in capabilities."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {
|
||||
"models": [
|
||||
{
|
||||
"key": "lmstudio-community/DeepSeek-R1-8B",
|
||||
"type": "llm",
|
||||
"display_name": "DeepSeek R1 8B",
|
||||
"max_context_length": 65536,
|
||||
"capabilities": {},
|
||||
},
|
||||
{
|
||||
"key": "lmstudio-community/Meta-Llama-3-8B",
|
||||
"type": "llm",
|
||||
"display_name": "Meta Llama 3 8B",
|
||||
"max_context_length": 8192,
|
||||
"capabilities": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
results = get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
deepseek = next(r for r in results if "DeepSeek" in r.display_name)
|
||||
llama = next(r for r in results if "Llama" in r.display_name)
|
||||
|
||||
assert deepseek.supports_reasoning is True
|
||||
assert llama.supports_reasoning is False
|
||||
|
||||
def test_uses_display_name_from_api(self, mock_lm_studio_response: dict) -> None:
|
||||
"""Test that display_name from the API is used directly."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_lm_studio_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
results = get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
llama = next(r for r in results if "Llama" in r.name)
|
||||
assert llama.display_name == "Meta Llama 3 8B"
|
||||
assert llama.max_input_tokens == 8192
|
||||
|
||||
def test_strips_trailing_v1_from_api_base(self) -> None:
|
||||
"""Test that /v1 suffix is stripped before building the native API URL."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {
|
||||
"models": [
|
||||
{
|
||||
"key": "test-model",
|
||||
"type": "llm",
|
||||
"display_name": "Test",
|
||||
"max_context_length": 4096,
|
||||
"capabilities": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234/v1")
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should hit /api/v1/models, not /v1/api/v1/models
|
||||
mock_httpx.get.assert_called_once()
|
||||
called_url = mock_httpx.get.call_args[0][0]
|
||||
assert called_url == "http://localhost:1234/api/v1/models"
|
||||
|
||||
def test_falls_back_to_stored_api_key(self) -> None:
|
||||
"""Test that stored API key is used when api_key_changed is False."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
"models": [
|
||||
{
|
||||
"key": "test-model",
|
||||
"type": "llm",
|
||||
"display_name": "Test",
|
||||
"max_context_length": 4096,
|
||||
"capabilities": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.llm.api.httpx") as mock_httpx,
|
||||
patch(
|
||||
"onyx.server.manage.llm.api.fetch_existing_llm_provider",
|
||||
return_value=mock_provider,
|
||||
),
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(
|
||||
api_base="http://localhost:1234",
|
||||
api_key="masked-value",
|
||||
api_key_changed=False,
|
||||
provider_name="my-lm-studio",
|
||||
)
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
headers = mock_httpx.get.call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer stored-secret"
|
||||
|
||||
def test_uses_submitted_api_key_when_changed(self) -> None:
|
||||
"""Test that submitted API key is used when api_key_changed is True."""
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {
|
||||
"models": [
|
||||
{
|
||||
"key": "test-model",
|
||||
"type": "llm",
|
||||
"display_name": "Test",
|
||||
"max_context_length": 4096,
|
||||
"capabilities": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(
|
||||
api_base="http://localhost:1234",
|
||||
api_key="new-secret",
|
||||
api_key_changed=True,
|
||||
provider_name="my-lm-studio",
|
||||
)
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
headers = mock_httpx.get.call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer new-secret"
|
||||
|
||||
def test_raises_on_empty_models(self) -> None:
|
||||
"""Test that an error is raised when no models are returned."""
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"models": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_raises_on_only_non_llm_models(self) -> None:
|
||||
"""Test that an error is raised when all models are non-LLM type."""
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.api import get_lm_studio_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {
|
||||
"models": [
|
||||
{
|
||||
"key": "embedding-model",
|
||||
"type": "embedding",
|
||||
"display_name": "Embedding",
|
||||
"max_context_length": 2048,
|
||||
"capabilities": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx") as mock_httpx:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx.get.return_value = mock_response
|
||||
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import unicodedata # used to verify NFC expansion test preconditions
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -158,3 +159,47 @@ def test_snippet_finding(test_data: TestSchema) -> None:
|
||||
f"end_idx mismatch: expected {test_data.expected_result.expected_end_idx}, "
|
||||
f"got {result.end_idx}"
|
||||
)
|
||||
|
||||
|
||||
# Characters confirmed to expand from 1 → 2 codepoints under NFC
|
||||
NFC_EXPANDING_CHARS = [
|
||||
("\u0958", "Devanagari letter qa"),
|
||||
("\u0959", "Devanagari letter khha"),
|
||||
("\u095a", "Devanagari letter ghha"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"char,description",
|
||||
NFC_EXPANDING_CHARS,
|
||||
)
|
||||
def test_nfc_expanding_char_snippet_match(char: str, description: str) -> None:
|
||||
"""Snippet matching should produce valid indices for content
|
||||
containing characters that expand under NFC normalization."""
|
||||
nfc = unicodedata.normalize("NFC", char)
|
||||
if len(nfc) <= 1:
|
||||
pytest.skip(f"{description} does not expand under NFC on this platform")
|
||||
|
||||
content = f"before {char} after"
|
||||
snippet = f"{char} after"
|
||||
|
||||
result = find_snippet_in_content(content, snippet)
|
||||
|
||||
assert result.snippet_located, f"[{description}] Snippet should be found in content"
|
||||
assert (
|
||||
0 <= result.start_idx < len(content)
|
||||
), f"[{description}] start_idx {result.start_idx} out of bounds"
|
||||
assert (
|
||||
0 <= result.end_idx < len(content)
|
||||
), f"[{description}] end_idx {result.end_idx} out of bounds"
|
||||
assert (
|
||||
result.start_idx <= result.end_idx
|
||||
), f"[{description}] start_idx {result.start_idx} > end_idx {result.end_idx}"
|
||||
|
||||
matched = content[result.start_idx : result.end_idx + 1]
|
||||
matched_nfc = unicodedata.normalize("NFC", matched)
|
||||
snippet_nfc = unicodedata.normalize("NFC", snippet)
|
||||
assert snippet_nfc in matched_nfc or matched_nfc in snippet_nfc, (
|
||||
f"[{description}] Matched span '{matched}' does not overlap "
|
||||
f"with expected snippet '{snippet}'"
|
||||
)
|
||||
394
backend/tests/unit/onyx/utils/test_json_river.py
Normal file
394
backend/tests/unit/onyx/utils/test_json_river.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""Tests for the jsonriver incremental JSON parser."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.jsonriver import JsonValue
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _all_deltas(chunks: list[str]) -> list[JsonValue]:
|
||||
"""Feed chunks one at a time and collect all emitted deltas."""
|
||||
parser = Parser()
|
||||
deltas: list[JsonValue] = []
|
||||
for chunk in chunks:
|
||||
deltas.extend(parser.feed(chunk))
|
||||
deltas.extend(parser.finish())
|
||||
return deltas
|
||||
|
||||
|
||||
class TestParseComplete:
|
||||
"""Parsing complete JSON in a single chunk."""
|
||||
|
||||
def test_simple_object(self) -> None:
|
||||
deltas = _all_deltas(['{"a": 1}'])
|
||||
assert any(r == {"a": 1.0} or r == {"a": 1} for r in deltas)
|
||||
|
||||
def test_simple_array(self) -> None:
|
||||
deltas = _all_deltas(["[1, 2, 3]"])
|
||||
assert any(isinstance(r, list) for r in deltas)
|
||||
|
||||
def test_simple_string(self) -> None:
|
||||
deltas = _all_deltas(['"hello"'])
|
||||
assert "hello" in deltas or any("hello" in str(r) for r in deltas)
|
||||
|
||||
def test_null(self) -> None:
|
||||
deltas = _all_deltas(["null"])
|
||||
assert None in deltas
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
deltas = _all_deltas(["true"])
|
||||
assert True in deltas
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
deltas = _all_deltas(["false"])
|
||||
assert any(r is False for r in deltas)
|
||||
|
||||
def test_number(self) -> None:
|
||||
deltas = _all_deltas(["42"])
|
||||
assert 42.0 in deltas
|
||||
|
||||
def test_negative_number(self) -> None:
|
||||
deltas = _all_deltas(["-3.14"])
|
||||
assert any(abs(r - (-3.14)) < 1e-10 for r in deltas if isinstance(r, float))
|
||||
|
||||
def test_empty_object(self) -> None:
|
||||
deltas = _all_deltas(["{}"])
|
||||
assert {} in deltas
|
||||
|
||||
def test_empty_array(self) -> None:
|
||||
deltas = _all_deltas(["[]"])
|
||||
assert [] in deltas
|
||||
|
||||
|
||||
class TestStreamingDeltas:
|
||||
"""Incremental feeding produces correct deltas."""
|
||||
|
||||
def test_object_string_value_streamed_char_by_char(self) -> None:
|
||||
chunks = list('{"code": "abc"}')
|
||||
deltas = _all_deltas(chunks)
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "code" in d:
|
||||
val = d["code"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "abc"
|
||||
|
||||
def test_object_streamed_in_two_halves(self) -> None:
|
||||
deltas = _all_deltas(['{"name": "Al', 'ice"}'])
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "name" in d:
|
||||
val = d["name"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "Alice"
|
||||
|
||||
def test_multiple_keys_streamed(self) -> None:
|
||||
deltas = _all_deltas(['{"a": "x', '", "b": "y"}'])
|
||||
a_parts: list[str] = []
|
||||
b_parts: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "a" in d and isinstance(d["a"], str):
|
||||
a_parts.append(d["a"])
|
||||
if "b" in d and isinstance(d["b"], str):
|
||||
b_parts.append(d["b"])
|
||||
assert "".join(a_parts) == "x"
|
||||
assert "".join(b_parts) == "y"
|
||||
|
||||
def test_deltas_only_contain_new_string_content(self) -> None:
|
||||
parser = Parser()
|
||||
d1 = parser.feed('{"msg": "hel')
|
||||
d2 = parser.feed('lo"}')
|
||||
parser.finish()
|
||||
|
||||
msg_parts = []
|
||||
for d in d1 + d2:
|
||||
if isinstance(d, dict) and "msg" in d:
|
||||
val = d["msg"]
|
||||
if isinstance(val, str):
|
||||
msg_parts.append(val)
|
||||
assert "".join(msg_parts) == "hello"
|
||||
|
||||
# Each delta should only contain new chars, not repeat previous ones
|
||||
if len(msg_parts) == 2:
|
||||
assert msg_parts[0] == "hel"
|
||||
assert msg_parts[1] == "lo"
|
||||
|
||||
|
||||
class TestEscapeSequences:
|
||||
"""JSON escape sequences are decoded correctly, even across chunk boundaries."""
|
||||
|
||||
def test_newline_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"text": "line1\\nline2"}'])
|
||||
text_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "text" in d and isinstance(d["text"], str):
|
||||
text_parts.append(d["text"])
|
||||
assert "".join(text_parts) == "line1\nline2"
|
||||
|
||||
def test_tab_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"t": "a\\tb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "t" in d and isinstance(d["t"], str):
|
||||
parts.append(d["t"])
|
||||
assert "".join(parts) == "a\tb"
|
||||
|
||||
def test_escaped_quote(self) -> None:
|
||||
deltas = _all_deltas(['{"q": "say \\"hi\\""}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "q" in d and isinstance(d["q"], str):
|
||||
parts.append(d["q"])
|
||||
assert "".join(parts) == 'say "hi"'
|
||||
|
||||
def test_unicode_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u0041\\u0042"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "AB"
|
||||
|
||||
def test_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"x": "a\\', 'nb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "x" in d and isinstance(d["x"], str):
|
||||
parts.append(d["x"])
|
||||
assert "".join(parts) == "a\nb"
|
||||
|
||||
def test_unicode_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u00', '41"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "A"
|
||||
|
||||
def test_backslash_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"p": "c:\\\\dir"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "p" in d and isinstance(d["p"], str):
|
||||
parts.append(d["p"])
|
||||
assert "".join(parts) == "c:\\dir"
|
||||
|
||||
|
||||
class TestNestedStructures:
|
||||
"""Nested objects and arrays produce correct deltas."""
|
||||
|
||||
def test_nested_object(self) -> None:
|
||||
deltas = _all_deltas(['{"outer": {"inner": "val"}}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "outer" in d:
|
||||
outer = d["outer"]
|
||||
if isinstance(outer, dict) and "inner" in outer:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
def test_array_of_strings(self) -> None:
|
||||
deltas = _all_deltas(['["a', '", "b"]'])
|
||||
all_items: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, list):
|
||||
for item in d:
|
||||
if isinstance(item, str):
|
||||
all_items.append(item)
|
||||
elif isinstance(d, str):
|
||||
all_items.append(d)
|
||||
joined = "".join(all_items)
|
||||
assert "a" in joined
|
||||
assert "b" in joined
|
||||
|
||||
def test_object_with_number_and_bool(self) -> None:
|
||||
deltas = _all_deltas(['{"count": 42, "active": true}'])
|
||||
has_count = False
|
||||
has_active = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "count" in d and d["count"] == 42.0:
|
||||
has_count = True
|
||||
if "active" in d and d["active"] is True:
|
||||
has_active = True
|
||||
assert has_count
|
||||
assert has_active
|
||||
|
||||
def test_object_with_null_value(self) -> None:
|
||||
deltas = _all_deltas(['{"key": null}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "key" in d and d["key"] is None:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
|
||||
class TestComputeDelta:
|
||||
"""Direct tests for the _compute_delta static method."""
|
||||
|
||||
def test_none_prev_returns_current(self) -> None:
|
||||
assert Parser._compute_delta(None, {"a": "b"}) == {"a": "b"}
|
||||
|
||||
def test_string_delta(self) -> None:
|
||||
assert Parser._compute_delta("hel", "hello") == "lo"
|
||||
|
||||
def test_string_no_change(self) -> None:
|
||||
assert Parser._compute_delta("same", "same") is None
|
||||
|
||||
def test_dict_new_key(self) -> None:
|
||||
assert Parser._compute_delta({"a": "x"}, {"a": "x", "b": "y"}) == {"b": "y"}
|
||||
|
||||
def test_dict_string_append(self) -> None:
|
||||
assert Parser._compute_delta({"code": "def"}, {"code": "def hello()"}) == {
|
||||
"code": " hello()"
|
||||
}
|
||||
|
||||
def test_dict_no_change(self) -> None:
|
||||
assert Parser._compute_delta({"a": 1}, {"a": 1}) is None
|
||||
|
||||
def test_list_new_items(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2, 3]) == [3]
|
||||
|
||||
def test_list_last_item_updated(self) -> None:
|
||||
assert Parser._compute_delta(["a"], ["ab"]) == ["ab"]
|
||||
|
||||
def test_list_no_change(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2]) is None
|
||||
|
||||
def test_primitive_change(self) -> None:
|
||||
assert Parser._compute_delta(1, 2) == 2
|
||||
|
||||
def test_primitive_no_change(self) -> None:
|
||||
assert Parser._compute_delta(42, 42) is None
|
||||
|
||||
|
||||
class TestParserLifecycle:
|
||||
"""Edge cases around parser state and lifecycle."""
|
||||
|
||||
def test_feed_after_finish_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
parser.feed('{"a": 1}')
|
||||
parser.finish()
|
||||
assert parser.feed("more") == []
|
||||
|
||||
def test_empty_feed_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed("") == []
|
||||
|
||||
def test_whitespace_only_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed(" ") == []
|
||||
|
||||
def test_finish_with_trailing_whitespace(self) -> None:
|
||||
parser = Parser()
|
||||
# Trailing whitespace terminates the number, so feed() emits it
|
||||
deltas = parser.feed("42 ")
|
||||
assert 42.0 in deltas
|
||||
parser.finish() # Should not raise
|
||||
|
||||
def test_finish_with_trailing_content_raises(self) -> None:
|
||||
parser = Parser()
|
||||
# Feed a complete JSON value followed by non-whitespace in one chunk
|
||||
parser.feed('{"a": 1} extra')
|
||||
with pytest.raises(ValueError, match="Unexpected trailing"):
|
||||
parser.finish()
|
||||
|
||||
def test_finish_flushes_pending_number(self) -> None:
|
||||
parser = Parser()
|
||||
deltas = parser.feed("42")
|
||||
# Number has no terminator, so feed() can't emit it yet
|
||||
assert deltas == []
|
||||
final = parser.finish()
|
||||
assert 42.0 in final
|
||||
|
||||
|
||||
class TestToolCallSimulation:
|
||||
"""Simulate the LLM tool-call streaming use case."""
|
||||
|
||||
def test_python_tool_call_streaming(self) -> None:
|
||||
full_json = json.dumps({"code": "print('hello world')"})
|
||||
chunk_size = 5
|
||||
chunks = [
|
||||
full_json[i : i + chunk_size] for i in range(0, len(full_json), chunk_size)
|
||||
]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == "print('hello world')"
|
||||
|
||||
def test_multi_arg_tool_call(self) -> None:
|
||||
full = '{"query": "search term", "num_results": 5}'
|
||||
chunks = [full[:15], full[15:30], full[30:]]
|
||||
|
||||
parser = Parser()
|
||||
query_parts: list[str] = []
|
||||
has_num_results = False
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
assert "".join(query_parts) == "search term"
|
||||
assert has_num_results
|
||||
|
||||
def test_code_with_newlines_and_escapes(self) -> None:
|
||||
code = 'def greet(name):\n print(f"Hello, {name}!")\n return True'
|
||||
full = json.dumps({"code": code})
|
||||
chunk_size = 8
|
||||
chunks = [full[i : i + chunk_size] for i in range(0, len(full), chunk_size)]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == code
|
||||
|
||||
def test_single_char_streaming(self) -> None:
|
||||
full = '{"key": "value"}'
|
||||
parser = Parser()
|
||||
key_parts: list[str] = []
|
||||
for ch in full:
|
||||
for delta in parser.feed(ch):
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
assert "".join(key_parts) == "value"
|
||||
@@ -9,8 +9,79 @@ from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.indexing import indexing_pipeline
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_node_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_hierarchy_node_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_json_like
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
# ---- sanitize_string tests ----
|
||||
|
||||
|
||||
def test_sanitize_string_strips_nul_bytes() -> None:
|
||||
assert sanitize_string("hello\x00world") == "helloworld"
|
||||
assert sanitize_string("\x00\x00\x00") == ""
|
||||
assert sanitize_string("clean") == "clean"
|
||||
|
||||
|
||||
def test_sanitize_string_strips_high_surrogates() -> None:
|
||||
assert sanitize_string("before\ud800after") == "beforeafter"
|
||||
assert sanitize_string("a\udbffb") == "ab"
|
||||
|
||||
|
||||
def test_sanitize_string_strips_low_surrogates() -> None:
|
||||
assert sanitize_string("before\udc00after") == "beforeafter"
|
||||
assert sanitize_string("a\udfffb") == "ab"
|
||||
|
||||
|
||||
def test_sanitize_string_strips_nul_and_surrogates_together() -> None:
|
||||
assert sanitize_string("he\x00llo\ud800 wo\udfffrld\x00") == "hello world"
|
||||
|
||||
|
||||
def test_sanitize_string_preserves_valid_unicode() -> None:
|
||||
assert sanitize_string("café ☕ 日本語 😀") == "café ☕ 日本語 😀"
|
||||
|
||||
|
||||
def test_sanitize_string_empty_input() -> None:
|
||||
assert sanitize_string("") == ""
|
||||
|
||||
|
||||
# ---- sanitize_json_like tests ----
|
||||
|
||||
|
||||
def test_sanitize_json_like_handles_plain_string() -> None:
|
||||
assert sanitize_json_like("he\x00llo\ud800") == "hello"
|
||||
|
||||
|
||||
def test_sanitize_json_like_handles_nested_dict() -> None:
|
||||
dirty = {
|
||||
"ke\x00y": "va\ud800lue",
|
||||
"nested": {"inne\x00r": "de\udfffep"},
|
||||
}
|
||||
assert sanitize_json_like(dirty) == {
|
||||
"key": "value",
|
||||
"nested": {"inner": "deep"},
|
||||
}
|
||||
|
||||
|
||||
def test_sanitize_json_like_handles_list_with_surrogates() -> None:
|
||||
dirty = ["a\x00", "b\ud800", {"c\udc00": "d\udfff"}]
|
||||
assert sanitize_json_like(dirty) == ["a", "b", {"c": "d"}]
|
||||
|
||||
|
||||
def test_sanitize_json_like_handles_tuple() -> None:
|
||||
dirty = ("a\x00", "b\ud800")
|
||||
assert sanitize_json_like(dirty) == ("a", "b")
|
||||
|
||||
|
||||
def test_sanitize_json_like_passes_through_non_strings() -> None:
|
||||
assert sanitize_json_like(42) == 42
|
||||
assert sanitize_json_like(3.14) == 3.14
|
||||
assert sanitize_json_like(True) is True
|
||||
assert sanitize_json_like(None) is None
|
||||
|
||||
|
||||
# ---- sanitize_document_for_postgres tests ----
|
||||
|
||||
|
||||
def test_sanitize_document_for_postgres_removes_nul_bytes() -> None:
|
||||
@@ -20,8 +20,6 @@ from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# Import the function under test
|
||||
|
||||
|
||||
class TestBuildVespaFilters:
|
||||
def test_empty_filters(self) -> None:
|
||||
@@ -179,11 +177,27 @@ class TestBuildVespaFilters:
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
def test_user_project_filter(self) -> None:
|
||||
"""Test user project filtering (replacement for user folder IDs)."""
|
||||
# Single project id
|
||||
"""Test user project filtering.
|
||||
|
||||
project_id alone does NOT trigger a knowledge scope restriction
|
||||
(an agent with no explicit knowledge should search everything).
|
||||
It only participates when explicit knowledge filters are present.
|
||||
"""
|
||||
# project_id alone → no restriction
|
||||
filters = IndexFilters(access_control_list=[], project_id=789)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({USER_PROJECT} contains "789") and ' == result
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
# project_id with user_file_ids → both OR'd
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], project_id=789, user_file_ids=[id1]
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
|
||||
== result
|
||||
)
|
||||
|
||||
# No project id
|
||||
filters = IndexFilters(access_control_list=[], project_id=None)
|
||||
@@ -217,7 +231,11 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
|
||||
def test_combined_filters(self) -> None:
|
||||
"""Test combining multiple filter types."""
|
||||
"""Test combining multiple filter types.
|
||||
|
||||
Knowledge-scope filters (document_set, user_file_ids, project_id,
|
||||
persona_id) are OR'd together, while all other filters are AND'd.
|
||||
"""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=["user1", "group1"],
|
||||
@@ -231,7 +249,6 @@ class TestBuildVespaFilters:
|
||||
|
||||
result = build_vespa_filters(filters)
|
||||
|
||||
# Build expected result piece by piece for readability
|
||||
expected = f"!({HIDDEN}=true) and "
|
||||
expected += (
|
||||
'(access_control_list contains "user1" or '
|
||||
@@ -239,9 +256,13 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
expected += f'({SOURCE_TYPE} contains "web") and '
|
||||
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
|
||||
expected += f'({DOCUMENT_SETS} contains "set1") and '
|
||||
expected += f'({DOCUMENT_ID} contains "{str(id1)}") and '
|
||||
expected += f'({USER_PROJECT} contains "789") and '
|
||||
# Knowledge scope filters are OR'd together
|
||||
expected += (
|
||||
f'(({DOCUMENT_SETS} contains "set1")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f' or ({USER_PROJECT} contains "789")'
|
||||
f") and "
|
||||
)
|
||||
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
|
||||
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
|
||||
@@ -251,6 +272,32 @@ class TestBuildVespaFilters:
|
||||
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
assert expected[:-5] == result_no_trailing # Remove trailing " and "
|
||||
|
||||
def test_knowledge_scope_single_filter_not_wrapped(self) -> None:
|
||||
"""When only one knowledge-scope filter is present it should not
|
||||
be wrapped in an extra OR group."""
|
||||
filters = IndexFilters(access_control_list=[], document_set=["set1"])
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
|
||||
|
||||
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
|
||||
"""Document set filter and user file IDs must be OR'd so that
|
||||
connector documents (in the set) and user files (with specific
|
||||
IDs) can both be found."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[],
|
||||
document_set=["engineering"],
|
||||
user_file_ids=[id1],
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
expected = (
|
||||
f"!({HIDDEN}=true) and "
|
||||
f'(({DOCUMENT_SETS} contains "engineering")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f") and "
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
def test_empty_or_none_values(self) -> None:
|
||||
"""Test with empty or None values in filter lists."""
|
||||
# Empty strings in document set
|
||||
|
||||
3
cli/.gitignore
vendored
Normal file
3
cli/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
onyx-cli
|
||||
cli
|
||||
onyx.cli
|
||||
22
cli/Dockerfile
Normal file
22
cli/Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
||||
FROM golang:1.26-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY ./ .
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o onyx-cli .
|
||||
RUN mkdir -p /home/onyx/.config
|
||||
|
||||
FROM scratch
|
||||
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=builder --chown=65534:65534 /home/onyx /home/onyx
|
||||
|
||||
COPY --from=builder /app/onyx-cli /onyx-cli
|
||||
|
||||
ENV HOME=/home/onyx
|
||||
ENV XDG_CONFIG_HOME=/home/onyx/.config
|
||||
|
||||
USER 65534:65534
|
||||
|
||||
ENTRYPOINT ["/onyx-cli"]
|
||||
161
cli/README.md
Normal file
161
cli/README.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# Onyx CLI
|
||||
|
||||
[](https://github.com/onyx-dot-app/onyx/actions/workflows/release-cli.yml)
|
||||
[](https://pypi.org/project/onyx-cli/)
|
||||
|
||||
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
Or with uv:
|
||||
|
||||
```shell
|
||||
uv pip install onyx-cli
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
Run the interactive setup:
|
||||
|
||||
```shell
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for your Onyx server URL and API key, tests the connection, and saves config to `~/.config/onyx-cli/config.json`.
|
||||
|
||||
Environment variables override config file values:
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
## Usage
|
||||
|
||||
### Interactive chat (default)
|
||||
|
||||
```shell
|
||||
onyx-cli
|
||||
```
|
||||
|
||||
### One-shot question
|
||||
|
||||
```shell
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
onyx-cli ask --agent-id 5 "Summarize this topic"
|
||||
onyx-cli ask --json "Hello"
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--agent-id <int>` | Agent ID to use (overrides default) |
|
||||
| `--json` | Output raw NDJSON events instead of plain text |
|
||||
|
||||
### List agents
|
||||
|
||||
```shell
|
||||
onyx-cli agents
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `chat` | Launch the interactive chat TUI (default) |
|
||||
| `ask` | Ask a one-shot question (non-interactive) |
|
||||
| `agents` | List available agents |
|
||||
| `configure` | Configure server URL and API key |
|
||||
| `validate-config` | Validate configuration and test connection |
|
||||
|
||||
## Slash Commands (in TUI)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show help message |
|
||||
| `/clear` | Clear chat and start a new session |
|
||||
| `/agent` | List and switch agents |
|
||||
| `/attach <path>` | Attach a file to next message |
|
||||
| `/sessions` | List recent chat sessions |
|
||||
| `/configure` | Re-run connection setup |
|
||||
| `/connectors` | Open connectors in browser |
|
||||
| `/settings` | Open settings in browser |
|
||||
| `/quit` | Exit Onyx CLI |
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `Enter` | Send message |
|
||||
| `Escape` | Cancel current generation |
|
||||
| `Ctrl+O` | Toggle source citations |
|
||||
| `Ctrl+D` | Quit (press twice) |
|
||||
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
|
||||
| `Page Up` / `Page Down` | Scroll half page |
|
||||
|
||||
## Building from Source
|
||||
|
||||
Requires [Go 1.24+](https://go.dev/dl/).
|
||||
|
||||
```shell
|
||||
cd cli
|
||||
go build -o onyx-cli .
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```shell
|
||||
# Run tests
|
||||
go test ./...
|
||||
|
||||
# Build
|
||||
go build -o onyx-cli .
|
||||
|
||||
# Lint
|
||||
staticcheck ./...
|
||||
```
|
||||
|
||||
## Publishing to PyPI
|
||||
|
||||
The CLI is distributed as a Python package via [PyPI](https://pypi.org/project/onyx-cli/). The build system uses [hatchling](https://hatch.pypa.io/) with [manygo](https://github.com/nicholasgasior/manygo) to cross-compile Go binaries into platform-specific wheels.
|
||||
|
||||
### CI release (recommended)
|
||||
|
||||
Tag a release and push — the `release-cli.yml` workflow builds wheels for all platforms and publishes to PyPI automatically:
|
||||
|
||||
```shell
|
||||
tag --prefix cli
|
||||
```
|
||||
|
||||
To do this manually:
|
||||
|
||||
```shell
|
||||
git tag cli/v0.1.0
|
||||
git push origin cli/v0.1.0
|
||||
```
|
||||
|
||||
The workflow builds wheels for: linux/amd64, linux/arm64, darwin/amd64, darwin/arm64, windows/amd64, windows/arm64.
|
||||
|
||||
### Manual release
|
||||
|
||||
Build a wheel locally with `uv`. Set `GOOS` and `GOARCH` to cross-compile for other platforms (Go handles this natively — no cross-compiler needed):
|
||||
|
||||
```shell
|
||||
# Build for current platform
|
||||
uv build --wheel
|
||||
|
||||
# Cross-compile for a different platform
|
||||
GOOS=linux GOARCH=amd64 uv build --wheel
|
||||
|
||||
# Upload to PyPI
|
||||
uv publish
|
||||
```
|
||||
|
||||
### Versioning
|
||||
|
||||
Versions are derived from git tags with the `cli/` prefix (e.g. `cli/v0.1.0`). The tag is parsed by `internal/_version.py` and injected into the Go binary via `-ldflags` at build time.
|
||||
63
cli/cmd/agents.go
Normal file
63
cli/cmd/agents.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newAgentsCmd() *cobra.Command {
|
||||
var agentsJSON bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "agents",
|
||||
Short: "List available agents",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
agents, err := client.ListAgents(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list agents: %w", err)
|
||||
}
|
||||
|
||||
if agentsJSON {
|
||||
data, err := json.MarshalIndent(agents, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal agents: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(agents) == 0 {
|
||||
fmt.Println("No agents available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "ID\tNAME\tDESCRIPTION")
|
||||
for _, a := range agents {
|
||||
desc := a.Description
|
||||
if len(desc) > 60 {
|
||||
desc = desc[:57] + "..."
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "%d\t%s\t%s\n", a.ID, a.Name, desc)
|
||||
}
|
||||
_ = w.Flush()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&agentsJSON, "json", false, "Output agents as JSON")
|
||||
|
||||
return cmd
|
||||
}
|
||||
124
cli/cmd/ask.go
Normal file
124
cli/cmd/ask.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newAskCmd() *cobra.Command {
|
||||
var (
|
||||
askAgentID int
|
||||
askJSON bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "ask [question]",
|
||||
Short: "Ask a one-shot question (non-interactive)",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
question := args[0]
|
||||
agentID := cfg.DefaultAgentID
|
||||
if cmd.Flags().Changed("agent-id") {
|
||||
agentID = askAgentID
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
parentID := -1
|
||||
ch := client.SendMessageStream(
|
||||
ctx,
|
||||
question,
|
||||
nil,
|
||||
agentID,
|
||||
&parentID,
|
||||
nil,
|
||||
)
|
||||
|
||||
var sessionID string
|
||||
var lastErr error
|
||||
gotStop := false
|
||||
for event := range ch {
|
||||
if e, ok := event.(models.SessionCreatedEvent); ok {
|
||||
sessionID = e.ChatSessionID
|
||||
}
|
||||
|
||||
if askJSON {
|
||||
wrapped := struct {
|
||||
Type string `json:"type"`
|
||||
Event models.StreamEvent `json:"event"`
|
||||
}{
|
||||
Type: event.EventType(),
|
||||
Event: event,
|
||||
}
|
||||
data, err := json.Marshal(wrapped)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling event: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
if _, ok := event.(models.ErrorEvent); ok {
|
||||
lastErr = fmt.Errorf("%s", event.(models.ErrorEvent).Error)
|
||||
}
|
||||
if _, ok := event.(models.StopEvent); ok {
|
||||
gotStop = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch e := event.(type) {
|
||||
case models.MessageDeltaEvent:
|
||||
fmt.Print(e.Content)
|
||||
case models.ErrorEvent:
|
||||
return fmt.Errorf("%s", e.Error)
|
||||
case models.StopEvent:
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
if sessionID != "" {
|
||||
client.StopChatSession(context.Background(), sessionID)
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
if !gotStop {
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return fmt.Errorf("stream ended unexpectedly")
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
|
||||
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
|
||||
// Suppress cobra's default error/usage on RunE errors
|
||||
return cmd
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user