mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-12 11:12:40 +00:00
Compare commits
112 Commits
nikg/std-e
...
xlsx-parse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83558ae04c | ||
|
|
005009602c | ||
|
|
b93875353b | ||
|
|
2290141b53 | ||
|
|
19b33e4d93 | ||
|
|
e56fa57c21 | ||
|
|
5cdeb84164 | ||
|
|
5b5100a07a | ||
|
|
77f58fbad5 | ||
|
|
cf74afc65e | ||
|
|
a887bc616c | ||
|
|
fef1fd093e | ||
|
|
8d085a4ccf | ||
|
|
28310b9138 | ||
|
|
f71fab580c | ||
|
|
89593b353f | ||
|
|
91e24ae63a | ||
|
|
d2b37724d1 | ||
|
|
87f0849330 | ||
|
|
2ec7526772 | ||
|
|
bbd68e2795 | ||
|
|
e74c36001a | ||
|
|
fe593a15da | ||
|
|
27df690a8d | ||
|
|
edbe569edd | ||
|
|
5118193d16 | ||
|
|
63d3efd380 | ||
|
|
ec978d9a3f | ||
|
|
d4d98a6cd0 | ||
|
|
dc40e86dac | ||
|
|
e495f7a13e | ||
|
|
4761e4b132 | ||
|
|
6b5ab54b85 | ||
|
|
959cf444f8 | ||
|
|
2ebccea6d6 | ||
|
|
5fe7a474db | ||
|
|
9d7dc3da21 | ||
|
|
2899be4c5e | ||
|
|
64ee7fc23f | ||
|
|
e07764285d | ||
|
|
cc2e6ffa8a | ||
|
|
d3ee5c9b59 | ||
|
|
dfa0efc093 | ||
|
|
9aad4077f1 | ||
|
|
29d9ebf7b3 | ||
|
|
f1df36e306 | ||
|
|
1611604269 | ||
|
|
c2a71091dc | ||
|
|
cc008699e5 | ||
|
|
48802618db | ||
|
|
6917953b86 | ||
|
|
e7cf027f8a | ||
|
|
41fb1480bb | ||
|
|
bdc2bfdcee | ||
|
|
8816d52b27 | ||
|
|
6590f1d7ba | ||
|
|
c527f75557 | ||
|
|
472d1788a7 | ||
|
|
99e95f8205 | ||
|
|
e618bf8385 | ||
|
|
f4dcd130ba | ||
|
|
910718deaa | ||
|
|
1a7ca93b93 | ||
|
|
a615a920cb | ||
|
|
29d8b310b5 | ||
|
|
d1409ccafa | ||
|
|
e41bad9103 | ||
|
|
661dc831dc | ||
|
|
19016dd35a | ||
|
|
127b2dcc80 | ||
|
|
b015a37cea | ||
|
|
b45277a8b0 | ||
|
|
893e8da79a | ||
|
|
a51f0d7cb2 | ||
|
|
c826d0469e | ||
|
|
0f6ae6f69c | ||
|
|
d0836e2603 | ||
|
|
bda03bafca | ||
|
|
376adff94a | ||
|
|
d2d4b89286 | ||
|
|
dde7a18bb7 | ||
|
|
3f004cf02f | ||
|
|
ae893079c3 | ||
|
|
189c07a913 | ||
|
|
2b82743bf5 | ||
|
|
ba2a5a60e1 | ||
|
|
5888f9d69f | ||
|
|
23b3a0a6ae | ||
|
|
eced88fa7a | ||
|
|
f59aaa902d | ||
|
|
57349bdbd1 | ||
|
|
192639a801 | ||
|
|
c10ffbb464 | ||
|
|
091f41fd1f | ||
|
|
45d77be4eb | ||
|
|
413fa85134 | ||
|
|
108cde4f55 | ||
|
|
f88ce32bd4 | ||
|
|
911f3439ea | ||
|
|
b02590d2b2 | ||
|
|
2d75b4b1f8 | ||
|
|
7e3f7d01c2 | ||
|
|
9d6ce26ea3 | ||
|
|
41713d42a2 | ||
|
|
8afc283410 | ||
|
|
b5c873077e | ||
|
|
20a4dd32eb | ||
|
|
fde0d44bc1 | ||
|
|
8fd91b6e83 | ||
|
|
8247fdd45b | ||
|
|
8c5859ba4d | ||
|
|
62ef6f59bb |
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,186 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
98
.github/workflows/deployment.yml
vendored
98
.github/workflows/deployment.yml
vendored
@@ -151,7 +151,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
@@ -182,9 +182,53 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
# Create GitHub release first, before desktop builds start.
|
||||
# This ensures all desktop matrix jobs upload to the same release instead of
|
||||
# racing to create duplicate releases.
|
||||
create-release:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
release-id: ${{ steps.create-release.outputs.id }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Determine release tag
|
||||
id: release-tag
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
body: "See the assets to download this version and install."
|
||||
draft: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
build-desktop:
|
||||
needs:
|
||||
- determine-builds
|
||||
- create-release
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
@@ -208,12 +252,12 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
# NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -353,11 +397,9 @@ jobs:
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
# Use the release created by the create-release job to avoid race conditions
|
||||
# when multiple matrix jobs try to create/update the same release simultaneously
|
||||
releaseId: ${{ needs.create-release.outputs.release-id }}
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
@@ -384,7 +426,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -458,7 +500,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -527,7 +569,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -597,7 +639,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -679,7 +721,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -756,7 +798,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -823,7 +865,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -896,7 +938,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -964,7 +1006,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1034,7 +1076,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1107,7 +1149,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1176,7 +1218,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1246,7 +1288,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1326,7 +1368,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1400,7 +1442,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1465,7 +1507,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1520,7 +1562,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1580,7 +1622,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -1637,7 +1679,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
@@ -15,7 +15,8 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
@@ -6,11 +6,13 @@ on:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
@@ -68,7 +70,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
56
.github/workflows/pr-golang-tests.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Golang Tests
|
||||
concurrency:
|
||||
group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
GO_VERSION: "1.26"
|
||||
|
||||
jobs:
|
||||
detect-modules:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
modules: ${{ steps.set-modules.outputs.modules }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
with:
|
||||
persist-credentials: false
|
||||
- id: set-modules
|
||||
run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
|
||||
|
||||
golang:
|
||||
needs: detect-modules
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
cache-dependency-path: "**/go.sum"
|
||||
|
||||
- run: go mod tidy
|
||||
working-directory: ${{ matrix.modules }}
|
||||
- run: git diff --exit-code go.mod go.sum
|
||||
working-directory: ${{ matrix.modules }}
|
||||
|
||||
- run: go test ./...
|
||||
working-directory: ${{ matrix.modules }}
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -316,6 +316,7 @@ jobs:
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -418,6 +419,7 @@ jobs:
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
@@ -637,6 +639,7 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -691,6 +694,7 @@ jobs:
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
|
||||
9
.github/workflows/pr-playwright-tests.yml
vendored
9
.github/workflows/pr-playwright-tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
# TODO: Remove this if we enable merge-queues for release branches.
|
||||
branches:
|
||||
- "release/**"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -461,14 +464,14 @@ jobs:
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -707,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
6
.github/workflows/pr-quality-checks.yml
vendored
6
.github/workflows/pr-quality-checks.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
@@ -38,9 +38,9 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.2.21'
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
|
||||
214
.github/workflows/release-cli.yml
vendored
Normal file
214
.github/workflows/release-cli.yml
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
name: Release CLI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "cli/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-cli
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
4
.github/workflows/release-devtools.yml
vendored
4
.github/workflows/release-devtools.yml
vendored
@@ -22,13 +22,11 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -48,6 +48,10 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
AWS_OIDC_ROLE_ARN:
|
||||
description: "AWS role ARN for OIDC auth"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -73,7 +77,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -116,7 +120,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -158,7 +162,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -264,7 +268,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
6
.github/workflows/sandbox-deployment.yml
vendored
6
.github/workflows/sandbox-deployment.yml
vendored
@@ -110,7 +110,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -180,7 +180,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
@@ -244,7 +244,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
69
.github/workflows/storybook-deploy.yml
vendored
Normal file
69
.github/workflows/storybook-deploy.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
name: Storybook Deploy
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
concurrency:
|
||||
group: storybook-deploy-production
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "web/lib/opal/**"
|
||||
- "web/src/refresh-components/**"
|
||||
- "web/.storybook/**"
|
||||
- "web/package.json"
|
||||
- "web/package-lock.json"
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
Deploy-Storybook:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: npm ci
|
||||
|
||||
- name: Build Storybook
|
||||
working-directory: web
|
||||
run: npm run storybook:build
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
if: always() && needs.Deploy-Storybook.result == 'failure'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
sparse-checkout: .github/actions/slack-notify
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• Deploy-Storybook"
|
||||
title: "🚨 Storybook Deploy Failed"
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -119,10 +119,11 @@ repos:
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
language_version: "1.26.0"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
|
||||
3
.vscode/env_template.txt
vendored
3
.vscode/env_template.txt
vendored
@@ -7,6 +7,9 @@
|
||||
|
||||
|
||||
AUTH_TYPE=basic
|
||||
# Recommended for basic auth - used for signing password reset and verification tokens
|
||||
# Generate a secure value with: openssl rand -hex 32
|
||||
USER_AUTH_SECRET=""
|
||||
DEV_MODE=true
|
||||
|
||||
|
||||
|
||||
@@ -104,6 +104,10 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
- Never enqueue a task without an expiration. Always supply `expires=` when
|
||||
sending tasks, either from the beat schedule or directly from another task. It
|
||||
should never be acceptable to submit code which enqueues tasks without an
|
||||
expiration, as doing so can lead to unbounded task queue growth.
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
@@ -540,6 +544,8 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
@@ -592,7 +598,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
|
||||
@@ -46,7 +46,9 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
vim \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -141,6 +143,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
@@ -164,6 +167,13 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""add hierarchy_node_by_connector_credential_pair table
|
||||
|
||||
Revision ID: b5c4d7e8f9a1
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-03-04
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "b5c4d7e8f9a1"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["hierarchy_node_id"],
|
||||
["hierarchy_node.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
table_name="hierarchy_node_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_table("hierarchy_node_by_connector_credential_pair")
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -22,59 +21,52 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -9,12 +9,15 @@ from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from onyx.access.access import collect_user_file_access
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -116,6 +119,68 @@ def _get_access_for_documents(
|
||||
return access_map
|
||||
|
||||
|
||||
def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
|
||||
"""Extract user-group names from the already-loaded Persona.groups
|
||||
relationships on a UserFile (skipping deleted personas)."""
|
||||
groups: set[str] = set()
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
for group in persona.groups:
|
||||
groups.add(group.name)
|
||||
return groups
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: extends the MIT user file ACL with user group names
|
||||
from personas shared via user groups.
|
||||
|
||||
Uses a single DB query (via fetch_user_files_with_access_relationships)
|
||||
that eagerly loads both the MIT-needed and EE-needed relationships.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
user_file_ids, db_session, eager_load_groups=True
|
||||
)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: works on pre-loaded UserFile objects.
|
||||
Expects Persona.groups to be eagerly loaded.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
if user_file.user is None:
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
continue
|
||||
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
group_names = _collect_user_file_group_names(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=list(group_names),
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
@@ -20,7 +21,13 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode
|
||||
|
||||
|
||||
def _build_hierarchy_access_filter(
|
||||
user_email: str | None,
|
||||
user_email: str,
|
||||
external_group_ids: list[str],
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build SQLAlchemy filter for hierarchy node access.
|
||||
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str | None,
|
||||
user_email: str,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import mark_persona_user_files_for_sync
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
@@ -26,7 +27,9 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -35,6 +38,7 @@ def update_persona_access(
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -54,6 +58,7 @@ def update_persona_access(
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -63,3 +68,7 @@ def update_persona_access(
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -117,6 +123,26 @@ def get_external_access_for_raw_gdrive_file(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
|
||||
@@ -246,7 +246,11 @@ async def get_billing_information(
|
||||
)
|
||||
except OnyxError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
if e.status_code in (
|
||||
OnyxErrorCode.BAD_GATEWAY.status_code,
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
|
||||
):
|
||||
_open_billing_circuit()
|
||||
raise
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -125,10 +126,16 @@ def _seed_llms(
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
seeded_providers: list[LLMProviderView] = []
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
try:
|
||||
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Failed to upsert LLM provider '%s' during seeding: %s",
|
||||
llm_upsert_request.name,
|
||||
e,
|
||||
)
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
|
||||
@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
return encoded_key
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_str.encode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_bytes.decode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -12,6 +11,7 @@ from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -132,19 +132,61 @@ def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "get_access_for_user_files_impl"
|
||||
)
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
return versioned_fn(user_file_ids, db_session)
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Compute access from pre-loaded UserFile objects (with relationships).
|
||||
Callers must ensure UserFile.user, Persona.users, and Persona.user are
|
||||
eagerly loaded (and Persona.groups for the EE path)."""
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "build_access_for_user_files_impl"
|
||||
)
|
||||
return versioned_fn(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=[],
|
||||
is_public=True if user_file.user is None else False,
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for user_file in user_files
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
|
||||
"""Collect all user emails that should have access to this user file.
|
||||
Includes the owner plus any users who have access via shared personas.
|
||||
Returns (emails, is_public)."""
|
||||
emails: set[str] = {user_file.user.email}
|
||||
is_public = False
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
if persona.is_public:
|
||||
is_public = True
|
||||
if persona.user_id is not None and persona.user:
|
||||
emails.add(persona.user.email)
|
||||
for shared_user in persona.users:
|
||||
emails.add(shared_user.email)
|
||||
return emails, is_public
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
@@ -145,10 +146,22 @@ def is_user_admin(user: User) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
"""Log warnings for AUTH_TYPE issues.
|
||||
|
||||
This only runs on app startup not during migrations/scripts.
|
||||
"""
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
|
||||
if raw_auth_type == "cloud":
|
||||
raise ValueError(
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
"'cloud' is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -115,8 +115,6 @@ def _extract_from_batch(
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
@@ -125,8 +123,7 @@ def _extract_from_batch(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
ids[item.id] = item.parent_hierarchy_raw_node_id
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
|
||||
|
||||
@@ -192,9 +189,7 @@ def extract_ids_from_runnable_connector(
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
@@ -289,6 +290,14 @@ def _run_hierarchy_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
|
||||
|
||||
@@ -48,10 +48,15 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -60,6 +65,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
@@ -579,11 +585,12 @@ def connector_pruning_generator_task(
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
@@ -592,6 +599,14 @@ def connector_pruning_generator_task(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -607,7 +622,6 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
@@ -664,6 +678,43 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
# --- Hierarchy node pruning ---
|
||||
live_node_ids = {n.id for n in upserted_nodes}
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
live_hierarchy_node_ids=live_node_ids,
|
||||
commit=True,
|
||||
)
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
reparented_nodes = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
if deleted_raw_ids:
|
||||
evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
|
||||
if reparented_nodes:
|
||||
reparented_cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in reparented_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client, source, reparented_cache_entries
|
||||
)
|
||||
if stale_removed or deleted_raw_ids or reparented_nodes:
|
||||
task_logger.info(
|
||||
f"Hierarchy node pruning: cc_pair={cc_pair_id} "
|
||||
f"stale_entries_removed={stale_removed} "
|
||||
f"nodes_deleted={len(deleted_raw_ids)} "
|
||||
f"nodes_reparented={len(reparented_nodes)}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} "
|
||||
|
||||
@@ -12,9 +12,9 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -43,7 +43,9 @@ from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -54,6 +56,7 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def _as_uuid(value: str | UUID) -> UUID:
|
||||
@@ -791,11 +794,12 @@ def project_sync_user_file_impl(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
[user_file_id],
|
||||
db_session,
|
||||
eager_load_groups=global_version.is_ee_version(),
|
||||
)
|
||||
user_file = user_files[0] if user_files else None
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
@@ -823,12 +827,21 @@ def project_sync_user_file_impl(
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
|
||||
file_id_str = str(user_file.id)
|
||||
access_map = build_access_for_user_files([user_file])
|
||||
access = access_map.get(file_id_str)
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
doc_id=file_id_str,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
fields=(
|
||||
VespaDocumentFields(access=access)
|
||||
if access is not None
|
||||
else None
|
||||
),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
|
||||
@@ -45,6 +45,7 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -58,8 +59,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -71,6 +70,8 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -587,6 +588,14 @@ def connector_document_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution during doc processing
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache_entries = [
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -51,6 +50,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
@@ -84,28 +84,6 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -686,12 +664,7 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
final_tools = []
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
@@ -1008,6 +981,10 @@ def run_llm_loop(
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
|
||||
saved_response = json.dumps(
|
||||
tool_response.rich_response.model_dump()
|
||||
)
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -54,7 +55,9 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -166,15 +169,6 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
|
||||
- NULL bytes (\x00): Not allowed in PostgreSQL text types
|
||||
- UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
|
||||
"""
|
||||
return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))
|
||||
|
||||
|
||||
def _try_parse_json_string(value: Any) -> Any:
|
||||
"""Attempt to parse a JSON string value into its Python equivalent.
|
||||
|
||||
@@ -222,9 +216,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_args, dict):
|
||||
# Parse any string values that look like JSON arrays/objects
|
||||
return {
|
||||
k: _try_parse_json_string(
|
||||
_sanitize_llm_output(v) if isinstance(v, str) else v
|
||||
)
|
||||
k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
|
||||
for k, v in raw_args.items()
|
||||
}
|
||||
|
||||
@@ -232,7 +224,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
# Sanitize before parsing to remove NULL bytes and surrogates
|
||||
raw_args = _sanitize_llm_output(raw_args)
|
||||
raw_args = sanitize_string(raw_args)
|
||||
|
||||
try:
|
||||
parsed1: Any = json.loads(raw_args)
|
||||
@@ -545,12 +537,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
return sanitize_string(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
value = sanitize_string(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
@@ -569,6 +561,7 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
arguments = sanitize_string(arguments)
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
@@ -1018,6 +1011,7 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1224,7 +1218,14 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -201,8 +202,13 @@ def save_chat_turn(
|
||||
pre_answer_processing_time: Duration of processing before answer starts (in seconds)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
sanitized_message_text = (
|
||||
sanitize_string(message_text) if message_text else message_text
|
||||
)
|
||||
assistant_message.message = sanitized_message_text
|
||||
assistant_message.reasoning_tokens = (
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
)
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Use pre-answer processing time (captured when MESSAGE_START was emitted)
|
||||
@@ -212,8 +218,10 @@ def save_chat_turn(
|
||||
# Calculate token count using default tokenizer, when storing, this should not use the LLM
|
||||
# specific one so we use a system default tokenizer here.
|
||||
default_tokenizer = get_tokenizer(None, None)
|
||||
if message_text:
|
||||
assistant_message.token_count = len(default_tokenizer.encode(message_text))
|
||||
if sanitized_message_text:
|
||||
assistant_message.token_count = len(
|
||||
default_tokenizer.encode(sanitized_message_text)
|
||||
)
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
@@ -328,8 +336,10 @@ def save_chat_turn(
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if message_text:
|
||||
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
|
||||
if sanitized_message_text:
|
||||
referenced = _extract_referenced_file_descriptors(
|
||||
tool_calls, sanitized_message_text
|
||||
)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,6 +68,10 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -92,19 +96,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Defaulting to 'basic'. Please update your configuration. "
|
||||
"Your existing data will be migrated automatically."
|
||||
)
|
||||
_auth_type_str = AuthType.BASIC.value
|
||||
try:
|
||||
# Silently default to basic - warnings/errors logged in verify_auth_setting()
|
||||
# which only runs on app startup, not during migrations/scripts
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
|
||||
AUTH_TYPE = AuthType(_auth_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
|
||||
else:
|
||||
AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
@@ -207,6 +204,12 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET:
|
||||
logger.warning(
|
||||
"USER_AUTH_SECRET is not set. This is required for secure password reset "
|
||||
"and email verification tokens. Please set USER_AUTH_SECRET in production."
|
||||
)
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
@@ -288,8 +291,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -204,7 +205,7 @@ def _manage_async_retrieval(
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
async def _async_fetch() -> AsyncGenerator[Document, None]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
@@ -227,22 +228,23 @@ def _manage_async_retrieval(
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
async_gen = _async_fetch()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
doc = loop.run_until_complete(anext(async_gen))
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
# Must close the async generator before the loop so the Discord
|
||||
# client's `async with` block can await its shutdown coroutine.
|
||||
# The nested try/finally ensures the loop always closes even if
|
||||
# aclose() raises (same pattern as cursor.close() before conn.close()).
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
@@ -1722,6 +1722,7 @@ class GoogleDriveConnector(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
|
||||
|
||||
@@ -476,6 +476,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -484,6 +485,8 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
When False (default), leave unprefixed (for permission sync path
|
||||
where upsert_document_external_perms handles prefixing).
|
||||
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
|
||||
files), fall back to granting access to this user.
|
||||
"""
|
||||
external_access_fn = cast(
|
||||
Callable[
|
||||
@@ -492,6 +495,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
str,
|
||||
GoogleDriveService | None,
|
||||
GoogleDriveService,
|
||||
str,
|
||||
bool,
|
||||
],
|
||||
ExternalAccess,
|
||||
@@ -507,6 +511,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain,
|
||||
retriever_drive_service,
|
||||
admin_drive_service,
|
||||
fallback_user_email,
|
||||
add_prefix,
|
||||
)
|
||||
|
||||
@@ -672,6 +677,7 @@ def _convert_drive_item_to_document(
|
||||
creds, user_email=permission_sync_context.primary_admin_email
|
||||
),
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
@@ -753,6 +759,7 @@ def build_slim_document(
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_email: str,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
@@ -774,6 +781,7 @@ def build_slim_document(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
|
||||
@@ -44,6 +44,7 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import unwrap_str
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -89,7 +90,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id))))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -178,7 +179,9 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
KV_CRED_KEY.format(credential_id),
|
||||
{"value": params.get("state", [None])[0]},
|
||||
encrypt=True,
|
||||
)
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -675,58 +676,43 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
if value and not sanitized:
|
||||
logger.warning("Sanitization removed all characters from string")
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
|
||||
"""Remove NUL (0x00) characters from all strings in a list."""
|
||||
return [_sanitize_for_postgres(v) for v in values]
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> DBSearchDoc:
|
||||
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
|
||||
db_search_doc = DBSearchDoc(
|
||||
document_id=_sanitize_for_postgres(server_search_doc.document_id),
|
||||
document_id=sanitize_string(server_search_doc.document_id),
|
||||
chunk_ind=server_search_doc.chunk_ind,
|
||||
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
|
||||
semantic_id=sanitize_string(server_search_doc.semantic_identifier),
|
||||
link=(
|
||||
_sanitize_for_postgres(server_search_doc.link)
|
||||
sanitize_string(server_search_doc.link)
|
||||
if server_search_doc.link is not None
|
||||
else None
|
||||
),
|
||||
blurb=_sanitize_for_postgres(server_search_doc.blurb),
|
||||
blurb=sanitize_string(server_search_doc.blurb),
|
||||
source_type=server_search_doc.source_type,
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=(
|
||||
_sanitize_for_postgres(server_search_doc.relevance_explanation)
|
||||
sanitize_string(server_search_doc.relevance_explanation)
|
||||
if server_search_doc.relevance_explanation is not None
|
||||
else None
|
||||
),
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=_sanitize_list_for_postgres(
|
||||
server_search_doc.match_highlights
|
||||
),
|
||||
match_highlights=[
|
||||
sanitize_string(h) for h in server_search_doc.match_highlights
|
||||
],
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.primary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.primary_owners]
|
||||
if server_search_doc.primary_owners is not None
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
|
||||
[sanitize_string(o) for o in server_search_doc.secondary_owners]
|
||||
if server_search_doc.secondary_owners is not None
|
||||
else None
|
||||
),
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -10,6 +13,7 @@ from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -458,7 +462,7 @@ def get_all_hierarchy_nodes_for_source(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str | None, # noqa: ARG001
|
||||
user_email: str, # noqa: ARG001
|
||||
external_group_ids: list[str], # noqa: ARG001
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -485,7 +489,7 @@ def _get_accessible_hierarchy_nodes_for_source(
|
||||
def get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str | None,
|
||||
user_email: str,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -620,3 +624,154 @@ def update_hierarchy_node_permissions(
|
||||
db_session.flush()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
hierarchy_node_ids: list[int],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.
|
||||
|
||||
This records that the given cc_pair "owns" these hierarchy nodes. Used by
|
||||
indexing, pruning, and hierarchy-fetching paths.
|
||||
"""
|
||||
if not hierarchy_node_ids:
|
||||
return
|
||||
|
||||
_M = HierarchyNodeByConnectorCredentialPair
|
||||
stmt = pg_insert(_M).values(
|
||||
[
|
||||
{
|
||||
_M.hierarchy_node_id: node_id,
|
||||
_M.connector_id: connector_id,
|
||||
_M.credential_id: credential_id,
|
||||
}
|
||||
for node_id in hierarchy_node_ids
|
||||
]
|
||||
)
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
db_session.execute(stmt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
live_hierarchy_node_ids: set[int],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Delete join-table rows for this cc_pair that are NOT in the live set.
|
||||
|
||||
If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
|
||||
(i.e. the connector no longer has any hierarchy nodes). Callers that want a
|
||||
no-op when there are no live nodes must guard before calling.
|
||||
|
||||
Returns the number of deleted rows.
|
||||
"""
|
||||
stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
if live_hierarchy_node_ids:
|
||||
stmt = stmt.where(
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
|
||||
live_hierarchy_node_ids
|
||||
)
|
||||
)
|
||||
|
||||
result: CursorResult = db_session.execute(stmt) # type: ignore[assignment]
|
||||
deleted = result.rowcount
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif deleted:
|
||||
db_session.flush()
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def delete_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[str]:
|
||||
"""Delete hierarchy nodes for a source that have zero cc_pair associations.
|
||||
|
||||
SOURCE-type nodes are excluded (they are synthetic roots).
|
||||
|
||||
Returns the list of raw_node_ids that were deleted (for cache eviction).
|
||||
"""
|
||||
# Find orphaned nodes: no rows in the join table
|
||||
orphan_stmt = (
|
||||
select(HierarchyNode.id, HierarchyNode.raw_node_id)
|
||||
.outerjoin(
|
||||
HierarchyNodeByConnectorCredentialPair,
|
||||
HierarchyNode.id
|
||||
== HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
|
||||
)
|
||||
.where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
|
||||
)
|
||||
)
|
||||
orphans = db_session.execute(orphan_stmt).all()
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
orphan_ids = [row[0] for row in orphans]
|
||||
deleted_raw_ids = [row[1] for row in orphans]
|
||||
|
||||
db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return deleted_raw_ids
|
||||
|
||||
|
||||
def reparent_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[HierarchyNode]:
|
||||
"""Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.
|
||||
|
||||
After pruning deletes stale nodes, their former children get parent_id=NULL
|
||||
via the SET NULL cascade. This function points them back to the SOURCE root.
|
||||
|
||||
Returns the reparented HierarchyNode objects (with updated parent_id)
|
||||
so callers can refresh downstream caches.
|
||||
"""
|
||||
source_node = get_source_hierarchy_node(db_session, source)
|
||||
if not source_node:
|
||||
return []
|
||||
|
||||
stmt = select(HierarchyNode).where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.parent_id.is_(None),
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
)
|
||||
orphans = list(db_session.execute(stmt).scalars().all())
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
for node in orphans:
|
||||
node.parent_id = source_node.id
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return orphans
|
||||
|
||||
@@ -25,8 +25,11 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
@@ -267,10 +270,35 @@ def upsert_llm_provider(
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of requested visibility by model name
|
||||
requested_visibility = {
|
||||
mc.name: mc.is_visible
|
||||
for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Delete removed models
|
||||
removed_ids = [
|
||||
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
|
||||
]
|
||||
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
# Prevent removing and hiding the default model
|
||||
if default_model:
|
||||
for name, mc in existing_by_name.items():
|
||||
if mc.id == default_model.id:
|
||||
if default_model.id in removed_ids:
|
||||
raise ValueError(
|
||||
f"Cannot remove the default model '{name}'. "
|
||||
"Please change the default model before removing."
|
||||
)
|
||||
if not requested_visibility.get(name, True):
|
||||
raise ValueError(
|
||||
f"Cannot hide the default model '{name}'. "
|
||||
"Please change the default model before hiding."
|
||||
)
|
||||
break
|
||||
|
||||
if removed_ids:
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.id.in_(removed_ids)
|
||||
@@ -535,7 +563,6 @@ def fetch_default_model(
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
@@ -811,6 +838,29 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default.
|
||||
# We flush (but don't commit) so that _update_default_model can see the new
|
||||
# model rows, then commit everything atomically to avoid a window where the
|
||||
# old default is invisible but still pointed-to.
|
||||
db_session.flush()
|
||||
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
|
||||
if (
|
||||
current_default
|
||||
and current_default.llm_provider_id == provider.id
|
||||
and current_default.name != recommended_default.name
|
||||
):
|
||||
_update_default_model__no_commit(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
@@ -942,7 +992,7 @@ def update_model_configuration__no_commit(
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
def _update_default_model__no_commit(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
@@ -980,6 +1030,14 @@ def _update_default_model(
|
||||
new_default.is_default = True
|
||||
model_config.is_visible = True
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
_update_default_model__no_commit(db_session, provider_id, model, flow_type)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
@@ -36,9 +37,11 @@ from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
@@ -117,10 +120,50 @@ class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
class _EncryptedBase(TypeDecorator):
|
||||
"""Base for encrypted column types that wrap values in SensitiveValue."""
|
||||
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def wrap_raw(self, value: Any) -> SensitiveValue:
|
||||
"""Encrypt a raw value and wrap it in SensitiveValue.
|
||||
|
||||
Called by the attribute set event so the Python-side type is always
|
||||
SensitiveValue, regardless of whether the value was loaded from the DB
|
||||
or assigned in application code.
|
||||
"""
|
||||
if self._is_json:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f"EncryptedJson column expected dict, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = json.dumps(value)
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"EncryptedString column expected str, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=encrypt_string_to_bytes(raw_str),
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=self._is_json,
|
||||
)
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
@@ -144,20 +187,9 @@ class EncryptedString(TypeDecorator):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
_is_json: bool = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
@@ -165,9 +197,7 @@ class EncryptedJson(TypeDecorator):
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
@@ -184,14 +214,40 @@ class EncryptedJson(TypeDecorator):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
_REGISTERED_ATTRS: set[str] = set()
|
||||
|
||||
|
||||
@event.listens_for(Mapper, "mapper_configured")
|
||||
def _register_sensitive_value_set_events(
|
||||
mapper: Mapper,
|
||||
class_: type,
|
||||
) -> None:
|
||||
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, _EncryptedBase):
|
||||
col_type = col.type
|
||||
attr = getattr(class_, prop.key)
|
||||
|
||||
# Guard against double-registration (e.g. if mapper is
|
||||
# re-configured in test setups)
|
||||
attr_key = f"{class_.__qualname__}.{prop.key}"
|
||||
if attr_key in _REGISTERED_ATTRS:
|
||||
continue
|
||||
_REGISTERED_ATTRS.add(attr_key)
|
||||
|
||||
@event.listens_for(attr, "set", retval=True)
|
||||
def _wrap_value(
|
||||
target: Any, # noqa: ARG001
|
||||
value: Any,
|
||||
oldvalue: Any, # noqa: ARG001
|
||||
initiator: Any, # noqa: ARG001
|
||||
_col_type: _EncryptedBase = col_type,
|
||||
) -> Any:
|
||||
if value is not None and not isinstance(value, SensitiveValue):
|
||||
return _col_type.wrap_raw(value)
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -2370,6 +2426,38 @@ class SyncRecord(Base):
|
||||
)
|
||||
|
||||
|
||||
class HierarchyNodeByConnectorCredentialPair(Base):
|
||||
"""Tracks which cc_pairs reference each hierarchy node.
|
||||
|
||||
During pruning, stale entries are removed for the current cc_pair.
|
||||
Hierarchy nodes with zero remaining entries are then deleted.
|
||||
"""
|
||||
|
||||
__tablename__ = "hierarchy_node_by_connector_credential_pair"
|
||||
|
||||
hierarchy_node_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
connector_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
Index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentByConnectorCredentialPair(Base):
|
||||
"""Represents an indexing of a document by a specific connector / credential pair"""
|
||||
|
||||
|
||||
@@ -205,7 +205,9 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -213,6 +215,7 @@ def update_persona_access(
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -233,6 +236,7 @@ def update_persona_access(
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -240,6 +244,10 @@ def update_persona_access(
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
|
||||
def create_update_persona(
|
||||
persona_id: int | None,
|
||||
@@ -851,6 +859,24 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_persona_user_files_for_sync(
|
||||
persona_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When persona sharing changes, mark all of its user files for sync
|
||||
so that their ACLs get updated in the vector DB."""
|
||||
persona = (
|
||||
db_session.query(Persona)
|
||||
.options(selectinload(Persona.user_files))
|
||||
.filter(Persona.id == persona_id)
|
||||
.first()
|
||||
)
|
||||
if not persona:
|
||||
return
|
||||
file_ids = [uf.id for uf in persona.user_files]
|
||||
_mark_files_need_persona_sync(db_session, file_ids)
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
|
||||
161
backend/onyx/db/rotate_encryption_key.py
Normal file
161
backend/onyx/db/rotate_encryption_key.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -13,12 +13,15 @@ from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import MCPServerStatus
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.server.features.tool.models import Header
|
||||
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_json_like
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -159,10 +162,26 @@ def update_tool(
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
old_oauth_config_id = tool.oauth_config_id
|
||||
if not isinstance(oauth_config_id, UnsetType):
|
||||
tool.oauth_config_id = oauth_config_id
|
||||
db_session.commit()
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if the oauth_config_id was changed
|
||||
if (
|
||||
old_oauth_config_id is not None
|
||||
and not isinstance(oauth_config_id, UnsetType)
|
||||
and old_oauth_config_id != oauth_config_id
|
||||
):
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
|
||||
db_session.commit()
|
||||
return tool
|
||||
|
||||
|
||||
@@ -171,8 +190,21 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
oauth_config_id = tool.oauth_config_id
|
||||
|
||||
db_session.delete(tool)
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
db_session.flush()
|
||||
|
||||
# Clean up orphaned OAuthConfig if no other tools reference it
|
||||
if oauth_config_id is not None:
|
||||
other_tools = db_session.scalars(
|
||||
select(Tool).where(Tool.oauth_config_id == oauth_config_id)
|
||||
).all()
|
||||
if not other_tools:
|
||||
oauth_config = db_session.get(OAuthConfig, oauth_config_id)
|
||||
if oauth_config:
|
||||
db_session.delete(oauth_config)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_builtin_tool(
|
||||
@@ -256,11 +288,13 @@ def create_tool_call_no_commit(
|
||||
tab_index=tab_index,
|
||||
tool_id=tool_id,
|
||||
tool_call_id=tool_call_id,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
tool_call_arguments=tool_call_arguments,
|
||||
tool_call_response=tool_call_response,
|
||||
reasoning_tokens=(
|
||||
sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
|
||||
),
|
||||
tool_call_arguments=sanitize_json_like(tool_call_arguments),
|
||||
tool_call_response=sanitize_json_like(tool_call_response),
|
||||
tool_call_tokens=tool_call_tokens,
|
||||
generated_images=generated_images,
|
||||
generated_images=sanitize_json_like(generated_images),
|
||||
)
|
||||
|
||||
db_session.add(tool_call)
|
||||
|
||||
@@ -3,9 +3,11 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
@@ -118,3 +120,31 @@ def get_file_ids_by_user_file_ids(
|
||||
) -> list[str]:
|
||||
user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
return [user_file.file_id for user_file in user_files]
|
||||
|
||||
|
||||
def fetch_user_files_with_access_relationships(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
eager_load_groups: bool = False,
|
||||
) -> list[UserFile]:
|
||||
"""Fetch user files with the owner and assistant relationships
|
||||
eagerly loaded (needed for computing access control).
|
||||
|
||||
When eager_load_groups is True, Persona.groups is also loaded so that
|
||||
callers can extract user-group names without a second DB round-trip."""
|
||||
persona_sub_options = [
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.user),
|
||||
]
|
||||
if eager_load_groups:
|
||||
persona_sub_options.append(selectinload(Persona.groups))
|
||||
|
||||
return (
|
||||
db_session.query(UserFile)
|
||||
.options(
|
||||
joinedload(UserFile.user),
|
||||
selectinload(UserFile.assistants).options(*persona_sub_options),
|
||||
)
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -61,6 +61,25 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class IndexInfo(BaseModel):
|
||||
"""
|
||||
Represents information about an OpenSearch index.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
name: str
|
||||
health: str
|
||||
status: str
|
||||
num_primary_shards: str
|
||||
num_replica_shards: str
|
||||
docs_count: str
|
||||
docs_deleted: str
|
||||
created_at: str
|
||||
total_size: str
|
||||
primary_shards_size: str
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
@@ -159,8 +178,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -173,8 +192,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -198,6 +217,34 @@ class OpenSearchClient(AbstractContextManager):
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def list_indices_with_info(self) -> list[IndexInfo]:
|
||||
"""
|
||||
Lists the indices in the OpenSearch cluster with information about each
|
||||
index.
|
||||
|
||||
Returns:
|
||||
A list of IndexInfo objects for each index.
|
||||
"""
|
||||
response = self._client.cat.indices(format="json")
|
||||
indices: list[IndexInfo] = []
|
||||
for raw_index_info in response:
|
||||
indices.append(
|
||||
IndexInfo(
|
||||
name=raw_index_info.get("index", ""),
|
||||
health=raw_index_info.get("health", ""),
|
||||
status=raw_index_info.get("status", ""),
|
||||
num_primary_shards=raw_index_info.get("pri", ""),
|
||||
num_replica_shards=raw_index_info.get("rep", ""),
|
||||
docs_count=raw_index_info.get("docs.count", ""),
|
||||
docs_deleted=raw_index_info.get("docs.deleted", ""),
|
||||
created_at=raw_index_info.get("creation.date.string", ""),
|
||||
total_size=raw_index_info.get("store.size", ""),
|
||||
primary_shards_size=raw_index_info.get("pri.store.size", ""),
|
||||
)
|
||||
)
|
||||
return indices
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
@@ -739,7 +739,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
The number of chunks successfully deleted.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id=document_id,
|
||||
@@ -775,7 +776,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
specified documents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
for update_request in update_requests:
|
||||
properties_to_update: dict[str, Any] = dict()
|
||||
@@ -831,9 +833,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here.
|
||||
# TODO(andrei): Fix the aforementioned race condition.
|
||||
raise ChunkCountNotFoundError(
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
|
||||
"application used to permit this but is not a supported state for a document when using OpenSearch. "
|
||||
"The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
|
||||
f"Tried to update document {doc_id} but its chunk count is not known. "
|
||||
"Older versions of the application used to permit this but is not a "
|
||||
"supported state for a document when using OpenSearch. The document was "
|
||||
"likely just added to the indexing pipeline and the chunk count will be "
|
||||
"updated shortly."
|
||||
)
|
||||
if doc_chunk_count == 0:
|
||||
raise ValueError(
|
||||
@@ -865,7 +869,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
@@ -912,7 +917,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# TODO(andrei): This could be better, the caller should just make this
|
||||
# decision when passing in the query param. See the above comment in the
|
||||
@@ -932,8 +938,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid
|
||||
# search from a theoretical standpoint. Empirically on a small dataset
|
||||
# of up to 10K docs, it's not very different. Likely more impactful at
|
||||
# scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
@@ -960,7 +968,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
dirty: bool | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_random_search_query(
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -990,7 +999,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
complete.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
|
||||
f"{self._index_name}."
|
||||
)
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
|
||||
@@ -243,7 +243,8 @@ class DocumentChunk(BaseModel):
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@@ -284,19 +285,22 @@ class DocumentChunk(BaseModel):
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
|
||||
f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
|
||||
"current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
|
||||
f"{type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but "
|
||||
"multi-tenant mode is not enabled. This is unexpected because in single-tenant "
|
||||
"mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
@@ -352,8 +356,10 @@ class DocumentSchema:
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
# Language analyzer (e.g. english) stems at index and search time for variant matching.
|
||||
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
|
||||
# Language analyzer (e.g. english) stems at index and search
|
||||
# time for variant matching. Configure via
|
||||
# OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
|
||||
# after a change.
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
|
||||
@@ -698,41 +698,6 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -758,41 +723,53 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -37,30 +34,22 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -73,16 +62,12 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -104,8 +89,7 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -114,16 +98,14 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -135,8 +117,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -147,8 +129,7 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -160,73 +141,94 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# filter_str += _build_kg_filter(
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# )
|
||||
# ))
|
||||
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# )
|
||||
# ))
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ class OnyxErrorCode(Enum):
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
{"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
If no message is supplied, the error code itself is used as the detail.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
"detail": message or self.code,
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
``{"error_code": "...", "detail": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -37,21 +37,22 @@ class OnyxError(Exception):
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
detail: Human-readable detail (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
detail: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_detail = detail or error_code.code
|
||||
super().__init__(resolved_detail)
|
||||
self.error_code = error_code
|
||||
self.message = message or error_code.code
|
||||
self.detail = resolved_detail
|
||||
self._status_code_override = status_code_override
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
@@ -72,11 +73,11 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@@ -19,6 +20,7 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
@@ -352,6 +354,65 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
# Column cleanup (transpose, clean, transpose back)
|
||||
transposed = list(map(list, zip(*matrix))) if matrix else []
|
||||
transposed = _remove_empty_runs(transposed, max_empty=MAX_EMPTY_COLS)
|
||||
matrix = list(map(list, zip(*transposed))) if transposed else []
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Leading and trailing empty rows are always dropped regardless of run length,
|
||||
since there is no adjacent non-empty row to bound the run.
|
||||
"""
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
@@ -390,30 +451,15 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
@@ -19,12 +19,16 @@ class OnyxMimeTypes:
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-log",
|
||||
"text/x-config",
|
||||
"text/tab-separated-values",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"text/yaml",
|
||||
"text/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
PDF_MIME_TYPE,
|
||||
|
||||
@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -65,6 +64,7 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from typing import cast
|
||||
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -7,6 +8,19 @@ class KvKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def unwrap_str(val: JSON_ro) -> str:
|
||||
"""Unwrap a string stored as {"value": str} in the encrypted KV store.
|
||||
Also handles legacy plain-string values cached in Redis."""
|
||||
if isinstance(val, dict):
|
||||
try:
|
||||
return cast(str, val["value"])
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Expected dict with 'value' key, got keys: {list(val.keys())}"
|
||||
)
|
||||
return cast(str, val)
|
||||
|
||||
|
||||
class KeyValueStore:
|
||||
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
|
||||
# It's read from the global thread level variable
|
||||
|
||||
@@ -22,6 +22,7 @@ class LlmProviderNames(str, Enum):
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE = "azure"
|
||||
OLLAMA_CHAT = "ollama_chat"
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
|
||||
@@ -41,6 +42,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
]
|
||||
|
||||
|
||||
@@ -56,6 +58,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.AZURE: "Azure",
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -103,6 +106,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ from onyx.llm.multi_llm import LitellmLLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
@@ -32,14 +34,18 @@ logger = setup_logger()
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
|
||||
raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
|
||||
api_key = raw.strip() if raw else None
|
||||
if not api_key:
|
||||
return {}
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
return {"Authorization": api_key}
|
||||
return {
|
||||
"Authorization": (
|
||||
api_key
|
||||
if api_key.lower().startswith("bearer ")
|
||||
else f"Bearer {api_key}"
|
||||
)
|
||||
}
|
||||
|
||||
# Passing these will put Onyx on the OpenRouter leaderboard
|
||||
elif provider == LlmProviderNames.OPENROUTER:
|
||||
|
||||
@@ -1512,6 +1512,10 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1526,6 +1530,10 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -2516,6 +2524,10 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.4": {
|
||||
"display_name": "GPT-5.4",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import (
|
||||
@@ -92,6 +93,98 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
return [prompt.model_dump(exclude_none=True)]
|
||||
|
||||
|
||||
def _normalize_content(raw: Any) -> str:
|
||||
"""Normalize a message content field to a plain string.
|
||||
|
||||
Content can be a string, None, or a list of content-block dicts
|
||||
(e.g. [{"type": "text", "text": "..."}]).
|
||||
"""
|
||||
if raw is None:
|
||||
return ""
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, list):
|
||||
return "\n".join(
|
||||
block.get("text", "") if isinstance(block, dict) else str(block)
|
||||
for block in raw
|
||||
)
|
||||
return str(raw)
|
||||
|
||||
|
||||
def _strip_tool_content_from_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tool-related messages to plain text.
|
||||
|
||||
Bedrock's Converse API requires toolConfig when messages contain
|
||||
toolUse/toolResult content blocks. When no tools are provided for the
|
||||
current request, we must convert any tool-related history into plain text
|
||||
to avoid the "toolConfig field must be defined" error.
|
||||
|
||||
This is the same approach used by _OllamaHistoryMessageFormatter.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert structured tool calls to text representation
|
||||
tool_call_lines = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
args = func.get("arguments", "{}")
|
||||
tc_id = tc.get("id", "")
|
||||
tool_call_lines.append(
|
||||
f"[Tool Call] name={name} id={tc_id} args={args}"
|
||||
)
|
||||
|
||||
existing_content = _normalize_content(msg.get("content"))
|
||||
parts = (
|
||||
[existing_content] + tool_call_lines
|
||||
if existing_content
|
||||
else tool_call_lines
|
||||
)
|
||||
new_msg = {
|
||||
"role": "assistant",
|
||||
"content": "\n".join(parts),
|
||||
}
|
||||
result.append(new_msg)
|
||||
|
||||
elif role == "tool":
|
||||
# Convert tool response to user message with text content
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = _normalize_content(msg.get("content"))
|
||||
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
|
||||
# Merge into previous user message if it is also a converted
|
||||
# tool result to avoid consecutive user messages (Bedrock requires
|
||||
# strict user/assistant alternation).
|
||||
if (
|
||||
result
|
||||
and result[-1]["role"] == "user"
|
||||
and "[Tool Result]" in result[-1].get("content", "")
|
||||
):
|
||||
result[-1]["content"] += "\n\n" + tool_result_text
|
||||
else:
|
||||
result.append({"role": "user", "content": tool_result_text})
|
||||
|
||||
else:
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -157,6 +250,9 @@ class LitellmLLM(LLM):
|
||||
elif model_provider == LlmProviderNames.OLLAMA_CHAT:
|
||||
if k == OLLAMA_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.LM_STUDIO:
|
||||
if k == LM_STUDIO_API_KEY_CONFIG_KEY:
|
||||
model_kwargs["api_key"] = v
|
||||
elif model_provider == LlmProviderNames.BEDROCK:
|
||||
if k == AWS_REGION_NAME_KWARG:
|
||||
model_kwargs[k] = v
|
||||
@@ -173,6 +269,19 @@ class LitellmLLM(LLM):
|
||||
elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT:
|
||||
model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v
|
||||
|
||||
# LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided,
|
||||
# which LM Studio rejects. Ensure we always pass an explicit key (or empty
|
||||
# string) to prevent LiteLLM from injecting its fake default.
|
||||
if model_provider == LlmProviderNames.LM_STUDIO:
|
||||
model_kwargs.setdefault("api_key", "")
|
||||
|
||||
# Users provide the server root (e.g. http://localhost:1234) but LiteLLM
|
||||
# needs /v1 for OpenAI-compatible calls.
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# Default vertex_location to "global" if not provided for Vertex AI
|
||||
# Latest gemini models are only available through the global region
|
||||
if (
|
||||
@@ -404,13 +513,30 @@ class LitellmLLM(LLM):
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
messages = _prompt_to_dicts(prompt)
|
||||
|
||||
# Bedrock's Converse API requires toolConfig when messages
|
||||
# contain toolUse/toolResult content blocks. When no tools are
|
||||
# provided for this request but the history contains tool
|
||||
# content from previous turns, strip it to plain text.
|
||||
is_bedrock = self._model_provider in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}
|
||||
if (
|
||||
is_bedrock
|
||||
and not tools
|
||||
and _messages_contain_tool_content(messages)
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
|
||||
@@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None:
|
||||
error_msg = None
|
||||
for _ in range(2):
|
||||
try:
|
||||
llm.invoke(UserMessage(content="Do not respond"))
|
||||
llm.invoke(UserMessage(content="Do not respond"), max_tokens=50)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
@@ -1,52 +1,27 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY,
|
||||
}
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -54,13 +29,6 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
@@ -44,6 +45,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
}
|
||||
|
||||
@@ -323,6 +325,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
_ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)",
|
||||
OLLAMA_PROVIDER_NAME: "Ollama",
|
||||
LM_STUDIO_PROVIDER_NAME: "LM Studio",
|
||||
ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)",
|
||||
AZURE_PROVIDER_NAME: "Azure OpenAI",
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-02-05T00:00:00Z",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.2" },
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
"additional_visible_models": [
|
||||
{ "name": "gpt-5-mini" },
|
||||
{ "name": "gpt-4.1" }
|
||||
{ "name": "gpt-5.4" },
|
||||
{ "name": "gpt-5.2" }
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
@@ -16,6 +16,10 @@
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -26,6 +27,14 @@ async def search_indexed_documents(
|
||||
Use this tool for information that is not public knowledge and specific to the user,
|
||||
their team, their work, or their organization/company.
|
||||
|
||||
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
|
||||
on every call, consuming tokens and adding latency.
|
||||
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
|
||||
but this should still be sufficient for most use cases. CE mode functionality should be swapped
|
||||
when a dedicated CE search endpoint is implemented.
|
||||
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
@@ -111,48 +120,73 @@ async def search_indexed_documents(
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
# Build the search request using the new SendSearchQueryRequest format
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
|
||||
# Call the API server using the new send-search-message route
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
|
||||
endpoint,
|
||||
json=search_request,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get("error"):
|
||||
if result.get(error_key):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get("error"),
|
||||
"error": result.get(error_key),
|
||||
}
|
||||
|
||||
# Return simplified format for MCP clients
|
||||
fields_to_return = [
|
||||
"semantic_identifier",
|
||||
"content",
|
||||
"source_type",
|
||||
"link",
|
||||
"score",
|
||||
]
|
||||
documents = [
|
||||
{key: doc.get(key) for key in fields_to_return}
|
||||
for doc in result.get("search_docs", [])
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
# backend retrieves fewer documents than requested.
|
||||
documents = documents[:limit]
|
||||
|
||||
logger.info(
|
||||
f"Onyx MCP Server: Internal search returned {len(documents)} results"
|
||||
)
|
||||
@@ -160,7 +194,6 @@ async def search_indexed_documents(
|
||||
"documents": documents,
|
||||
"total_results": len(documents),
|
||||
"query": query,
|
||||
"executed_queries": result.get("all_executed_queries", [query]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
|
||||
|
||||
@@ -16,6 +16,7 @@ Cache Strategy:
|
||||
using only the SOURCE-type node as the ancestor
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -204,6 +205,30 @@ def cache_hierarchy_nodes_batch(
|
||||
redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
def evict_hierarchy_nodes_from_cache(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_node_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove specific hierarchy nodes from the Redis cache.
|
||||
|
||||
Deletes entries from both the parent-chain hash and the raw_id→node_id hash.
|
||||
"""
|
||||
if not raw_node_ids:
|
||||
return
|
||||
|
||||
cache_key = _cache_key(source)
|
||||
raw_id_key = _raw_id_cache_key(source)
|
||||
|
||||
# Look up node_ids so we can remove them from the parent-chain hash
|
||||
raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids))
|
||||
node_id_strs = [v for v in raw_values if v is not None]
|
||||
|
||||
if node_id_strs:
|
||||
redis_client.hdel(cache_key, *node_id_strs)
|
||||
redis_client.hdel(raw_id_key, *raw_node_ids)
|
||||
|
||||
|
||||
def get_node_id_from_raw_id(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
|
||||
@@ -1905,7 +1905,7 @@ def get_connector_by_id(
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
@@ -1918,7 +1918,7 @@ def submit_connector_request(
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email if user else None
|
||||
user_email = user.email
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
|
||||
@@ -57,9 +57,6 @@ def list_messages(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageListResponse:
|
||||
"""Get all messages for a build session."""
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
messages = session_manager.list_messages(session_id, user.id)
|
||||
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.9",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz",
|
||||
"integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==",
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1573,27 +1573,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/balanced-match": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
|
||||
"integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/brace-expansion": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
|
||||
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@isaacs/balanced-match": "^4.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/gen-mapping": {
|
||||
"version": "0.3.13",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz",
|
||||
@@ -1680,9 +1659,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -3855,6 +3834,27 @@
|
||||
"path-browserify": "^1.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/balanced-match": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
|
||||
"integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/fast-glob": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz",
|
||||
@@ -3884,15 +3884,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/minimatch": {
|
||||
"version": "10.1.1",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz",
|
||||
"integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==",
|
||||
"version": "10.2.4",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz",
|
||||
"integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"dependencies": {
|
||||
"@isaacs/brace-expansion": "^5.0.0"
|
||||
"brace-expansion": "^5.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": "20 || >=22"
|
||||
"node": "18 || 20 || >=22"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
@@ -4234,13 +4234,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
|
||||
"integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
|
||||
"version": "9.0.9",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz",
|
||||
"integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^2.0.1"
|
||||
"brace-expansion": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16 || 14 >=14.17"
|
||||
@@ -4619,9 +4619,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "6.12.6",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"version": "6.14.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
|
||||
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -4653,9 +4653,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ajv-formats/node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
"integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
@@ -6758,12 +6758,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
|
||||
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.0.1"
|
||||
"ip-address": "10.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
@@ -7556,9 +7556,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
@@ -8831,9 +8831,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
@@ -9699,9 +9699,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.1",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
|
||||
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"side-channel": "^1.1.0"
|
||||
|
||||
@@ -54,18 +54,14 @@ def _require_opensearch(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _get_user_access_info(
|
||||
user: User | None, db_session: Session
|
||||
) -> tuple[str | None, list[str]]:
|
||||
if not user:
|
||||
return None, []
|
||||
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
|
||||
return user.email, get_user_external_group_ids(db_session, user)
|
||||
|
||||
|
||||
@router.get(HIERARCHY_NODES_LIST_PATH)
|
||||
def list_accessible_hierarchy_nodes(
|
||||
source: DocumentSource,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodesResponse:
|
||||
_require_opensearch(db_session)
|
||||
@@ -92,7 +88,7 @@ def list_accessible_hierarchy_nodes(
|
||||
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
|
||||
def list_accessible_hierarchy_node_documents(
|
||||
documents_request: HierarchyNodeDocumentsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodeDocumentsResponse:
|
||||
_require_opensearch(db_session)
|
||||
|
||||
@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
|
||||
@router.get("/servers", response_model=MCPServersResponse)
|
||||
def get_mcp_servers_for_user(
|
||||
db: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
) -> MCPServersResponse:
|
||||
"""List all MCP servers for use in agent configuration and chat UI.
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -35,6 +37,38 @@ def get_safe_filename(upload: UploadFile) -> str:
|
||||
return upload.filename
|
||||
|
||||
|
||||
def get_upload_size_bytes(upload: UploadFile) -> int | None:
|
||||
"""Best-effort file size in bytes without consuming the stream."""
|
||||
if upload.size is not None:
|
||||
return upload.size
|
||||
|
||||
try:
|
||||
current_pos = upload.file.tell()
|
||||
upload.file.seek(0, 2)
|
||||
size = upload.file.tell()
|
||||
upload.file.seek(current_pos)
|
||||
return size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not determine upload size via stream seek "
|
||||
f"(filename='{get_safe_filename(upload)}', "
|
||||
f"error_type={type(e).__name__}, error={e})"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
|
||||
"""Return True when upload size is known and exceeds max_bytes."""
|
||||
size_bytes = get_upload_size_bytes(upload)
|
||||
if size_bytes is None:
|
||||
logger.warning(
|
||||
"Could not determine upload size; skipping size-limit check for "
|
||||
f"'{get_safe_filename(upload)}'"
|
||||
)
|
||||
return False
|
||||
return size_bytes > max_bytes
|
||||
|
||||
|
||||
# Guard against extremely large images
|
||||
Image.MAX_IMAGE_PIXELS = 12000 * 12000
|
||||
|
||||
@@ -159,6 +193,18 @@ def categorize_uploaded_files(
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
extension = get_file_ext(filename)
|
||||
|
||||
# If image, estimate tokens via dedicated method first
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -27,8 +28,6 @@ from onyx.db.feedback import update_document_boost_for_user
|
||||
from onyx.db.feedback import update_document_hidden_for_user
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -125,11 +124,11 @@ def validate_existing_genai_api_key(
|
||||
try:
|
||||
llm = get_default_llm(timeout=10)
|
||||
except ValueError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "LLM not setup")
|
||||
raise HTTPException(status_code=404, detail="LLM not setup")
|
||||
|
||||
error = test_llm(llm)
|
||||
if error:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error)
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
|
||||
# Mark check as successful
|
||||
curr_time = datetime.now(tz=timezone.utc)
|
||||
@@ -160,7 +159,10 @@ def create_deletion_attempt_for_connector_id(
|
||||
f"'{credential_id}' does not exist. Has it already been deleted?"
|
||||
)
|
||||
logger.error(error)
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, error)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error,
|
||||
)
|
||||
|
||||
# Cancel any scheduled indexing attempts
|
||||
cancel_indexing_attempts_for_ccpair(
|
||||
@@ -176,9 +178,9 @@ def create_deletion_attempt_for_connector_id(
|
||||
# connector_credential_pair=cc_pair, db_session=db_session
|
||||
# )
|
||||
# if deletion_attempt_disallowed_reason:
|
||||
# raise OnyxError(
|
||||
# OnyxErrorCode.VALIDATION_ERROR,
|
||||
# deletion_attempt_disallowed_reason,
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=deletion_attempt_disallowed_reason,
|
||||
# )
|
||||
|
||||
# mark as deleting
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -22,8 +24,6 @@ from onyx.db.discord_bot import update_discord_channel_config
|
||||
from onyx.db.discord_bot import update_guild_config
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigCreateRequest
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordChannelConfigResponse
|
||||
@@ -47,14 +47,14 @@ def _check_bot_config_api_access() -> None:
|
||||
- When DISCORD_BOT_TOKEN env var is set (managed via env)
|
||||
"""
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Discord bot configuration is managed by Onyx on Cloud.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot configuration is managed by Onyx on Cloud.",
|
||||
)
|
||||
if DISCORD_BOT_TOKEN:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Discord bot is configured via environment variables. API access disabled.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot is configured via environment variables. API access disabled.",
|
||||
)
|
||||
|
||||
|
||||
@@ -92,9 +92,9 @@ def create_bot_request(
|
||||
bot_token=request.bot_token,
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
"Discord bot config already exists. Delete it first to create a new one.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Discord bot config already exists. Delete it first to create a new one.",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
@@ -117,7 +117,7 @@ def delete_bot_config_endpoint(
|
||||
"""
|
||||
deleted = delete_discord_bot_config(db_session)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Bot config not found")
|
||||
raise HTTPException(status_code=404, detail="Bot config not found")
|
||||
|
||||
# Also delete the service API key used by the Discord bot
|
||||
delete_discord_service_api_key(db_session)
|
||||
@@ -144,7 +144,7 @@ def delete_service_api_key_endpoint(
|
||||
"""
|
||||
deleted = delete_discord_service_api_key(db_session)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Service API key not found")
|
||||
raise HTTPException(status_code=404, detail="Service API key not found")
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
@@ -189,7 +189,7 @@ def get_guild_config(
|
||||
"""Get specific guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
return DiscordGuildConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ def update_guild_request(
|
||||
"""Update guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
config = update_guild_config(
|
||||
db_session,
|
||||
@@ -228,7 +228,7 @@ def delete_guild_request(
|
||||
"""
|
||||
deleted = delete_guild_config(db_session, config_id)
|
||||
if not deleted:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
# On Cloud, delete service API key when all guilds are removed
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
@@ -254,9 +254,9 @@ def list_channel_configs(
|
||||
"""List whitelisted channels for a guild."""
|
||||
guild_config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not guild_config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Guild config not found")
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
if not guild_config.guild_id:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Guild not yet registered")
|
||||
raise HTTPException(status_code=400, detail="Guild not yet registered")
|
||||
|
||||
configs = get_channel_configs(db_session, config_id)
|
||||
return [DiscordChannelConfigResponse.model_validate(c) for c in configs]
|
||||
@@ -278,7 +278,7 @@ def update_channel_request(
|
||||
db_session, guild_config_id, channel_config_id
|
||||
)
|
||||
if not config:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Channel config not found")
|
||||
raise HTTPException(status_code=404, detail="Channel config not found")
|
||||
|
||||
config = update_discord_channel_config(
|
||||
db_session,
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
@@ -15,8 +16,6 @@ from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.models import AllVersions
|
||||
from onyx.server.manage.models import AuthTypeResponse
|
||||
from onyx.server.manage.models import ContainerVersions
|
||||
@@ -105,14 +104,14 @@ def get_versions() -> AllVersions:
|
||||
|
||||
# Ensure we have at least one tag of each type
|
||||
if not dev_tags:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"No valid dev versions found matching pattern v(number).(number).(number)-beta.(number)",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid dev versions found matching pattern v(number).(number).(number)-beta.(number)",
|
||||
)
|
||||
if not stable_tags:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"No valid stable versions found matching pattern v(number).(number).(number)",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid stable versions found matching pattern v(number).(number).(number)",
|
||||
)
|
||||
|
||||
# Sort common tags and get the latest one
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -14,8 +15,6 @@ from onyx.db.llm import remove_llm_provider__no_commit
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
@@ -75,9 +74,9 @@ def _build_llm_provider_request(
|
||||
# Clone mode: Only use API key from source provider
|
||||
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -111,9 +110,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
@@ -125,9 +124,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not validate_credentials(provider, credentials):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Incorrect credentials for {provider}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Incorrect credentials for {provider}",
|
||||
)
|
||||
|
||||
return LLMProviderUpsertRequest(
|
||||
@@ -216,9 +215,9 @@ def test_image_generation(
|
||||
LLMProviderModel, test_request.source_llm_provider_id
|
||||
)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -237,9 +236,9 @@ def test_image_generation(
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -258,14 +257,14 @@ def test_image_generation(
|
||||
),
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Invalid image generation provider: {provider}",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Invalid image generation provider: {provider}",
|
||||
)
|
||||
except ImageProviderCredentialsError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHENTICATED,
|
||||
"Invalid image generation credentials",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid image generation credentials",
|
||||
)
|
||||
|
||||
quality = _get_test_quality_for_model(test_request.model_name)
|
||||
@@ -277,15 +276,15 @@ def test_image_generation(
|
||||
n=1,
|
||||
quality=quality,
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log only exception type to avoid exposing sensitive data
|
||||
# (LiteLLM errors may contain URLs with API keys or auth tokens)
|
||||
logger.warning(f"Image generation test failed: {type(e).__name__}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Image generation test failed: {type(e).__name__}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image generation test failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
|
||||
@@ -310,9 +309,9 @@ def create_config(
|
||||
db_session, config_create.image_provider_id
|
||||
)
|
||||
if existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -346,10 +345,10 @@ def create_config(
|
||||
db_session.commit()
|
||||
db_session.refresh(config)
|
||||
return ImageGenerationConfigView.from_model(config)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/config")
|
||||
@@ -374,9 +373,9 @@ def get_config_credentials(
|
||||
"""
|
||||
config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
return ImageGenerationCredentials.from_model(config)
|
||||
@@ -402,9 +401,9 @@ def update_config(
|
||||
# 1. Get existing config
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -473,10 +472,10 @@ def update_config(
|
||||
db_session.refresh(existing_config)
|
||||
return ImageGenerationConfigView.from_model(existing_config)
|
||||
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}")
|
||||
@@ -490,9 +489,9 @@ def delete_config(
|
||||
# Get the config first to find the associated LLM provider
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -504,10 +503,10 @@ def delete_config(
|
||||
remove_llm_provider__no_commit(db_session, llm_provider_id)
|
||||
|
||||
db_session.commit()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/config/{image_provider_id}/default")
|
||||
@@ -520,7 +519,7 @@ def set_config_as_default(
|
||||
try:
|
||||
set_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}/default")
|
||||
@@ -533,4 +532,4 @@ def unset_config_as_default(
|
||||
try:
|
||||
unset_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_available_well_known_llms,
|
||||
)
|
||||
@@ -62,6 +63,9 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -73,6 +77,7 @@ from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
|
||||
@@ -441,6 +446,18 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# When transitioning to auto mode, preserve existing model configurations
|
||||
# so the upsert doesn't try to delete them (which would trip the default
|
||||
# model protection guard). sync_auto_mode_models will handle the model
|
||||
# lifecycle afterward — adding new models, hiding removed ones, and
|
||||
# updating the default. This is safe even if sync fails: the provider
|
||||
# keeps its old models and default rather than losing them.
|
||||
if transitioning_to_auto_mode and existing_provider:
|
||||
llm_provider_upsert_request.model_configurations = [
|
||||
ModelConfigurationUpsertRequest.from_model(mc)
|
||||
for mc in existing_provider.model_configurations
|
||||
]
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
llm_provider_upsert_request=llm_provider_upsert_request,
|
||||
@@ -453,7 +470,6 @@ def put_llm_provider(
|
||||
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
@@ -1217,3 +1233,117 @@ def get_openrouter_available_models(
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/lm-studio/available-models")
|
||||
def get_lm_studio_available_models(
|
||||
request: LMStudioModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LMStudioFinalModelResponse]:
|
||||
"""Fetch available models from an LM Studio server.
|
||||
|
||||
Uses the LM Studio-native /api/v1/models endpoint which exposes
|
||||
rich metadata including capabilities (vision, reasoning),
|
||||
display names, and context lengths.
|
||||
"""
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
# Strip /v1 suffix that users may copy from OpenAI-compatible tool configs;
|
||||
# the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models.
|
||||
cleaned_api_base = cleaned_api_base.removesuffix("/v1")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch LM Studio models.",
|
||||
)
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LM Studio models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your LM Studio server.",
|
||||
)
|
||||
|
||||
results: list[LMStudioFinalModelResponse] = []
|
||||
for item in models:
|
||||
# Filter to LLM-type models only (skip embeddings, etc.)
|
||||
if item.get("type") != "llm":
|
||||
continue
|
||||
|
||||
model_key = item.get("key")
|
||||
if not model_key:
|
||||
continue
|
||||
|
||||
display_name = item.get("display_name") or model_key
|
||||
max_context_length = item.get("max_context_length")
|
||||
capabilities = item.get("capabilities") or {}
|
||||
|
||||
results.append(
|
||||
LMStudioFinalModelResponse(
|
||||
name=model_key,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_context_length,
|
||||
supports_image_input=capabilities.get("vision", False),
|
||||
supports_reasoning=capabilities.get("reasoning", False)
|
||||
or is_reasoning_model(model_key, display_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from LM Studio server.",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -371,6 +371,22 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# LM Studio dynamic models fetch
|
||||
class LMStudioModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
api_key_changed: bool = False
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LMStudioFinalModelResponse(BaseModel):
|
||||
name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B")
|
||||
display_name: str # Human-readable name
|
||||
max_input_tokens: int | None # From LM Studio API or None if unavailable
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import TypedDict
|
||||
|
||||
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR
|
||||
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -23,6 +24,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OPENROUTER,
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -348,4 +350,19 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
elif provider == LlmProviderNames.LM_STUDIO:
|
||||
# LM Studio model IDs can be paths like "publisher/model-name"
|
||||
# or simple names. Use MODEL_PREFIX_TO_VENDOR for matching.
|
||||
|
||||
model_lower = model_name.lower()
|
||||
# Check for slash-separated vendor prefix first
|
||||
if "/" in model_lower:
|
||||
vendor_key = model_lower.split("/")[0]
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
# Fallback to model prefix matching
|
||||
for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items():
|
||||
if model_lower.startswith(prefix):
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title())
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -25,8 +27,6 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
@@ -58,9 +58,9 @@ def set_new_search_settings(
|
||||
|
||||
# Disallow contextual RAG for cloud deployments.
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider.
|
||||
@@ -70,9 +70,9 @@ def set_new_search_settings(
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -188,7 +188,7 @@ def delete_search_settings_endpoint(
|
||||
search_settings_id=deletion_request.search_settings_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
@@ -238,9 +238,9 @@ def update_saved_search_settings(
|
||||
) -> None:
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -294,7 +294,7 @@ def validate_contextual_rag_model(
|
||||
model_name=model_name,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
|
||||
|
||||
|
||||
def _validate_contextual_rag_model(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -20,8 +21,6 @@ from onyx.db.slack_channel_config import fetch_slack_channel_configs
|
||||
from onyx.db.slack_channel_config import insert_slack_channel_config
|
||||
from onyx.db.slack_channel_config import remove_slack_channel_config
|
||||
from onyx.db.slack_channel_config import update_slack_channel_config
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.onyxbot.slack.config import validate_channel_name
|
||||
from onyx.server.manage.models import SlackBot
|
||||
from onyx.server.manage.models import SlackBotCreationRequest
|
||||
@@ -64,7 +63,10 @@ def _form_channel_config(
|
||||
current_slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
if respond_tag_only and respond_member_group_list:
|
||||
raise ValueError(
|
||||
@@ -121,7 +123,10 @@ def create_slack_channel_config(
|
||||
)
|
||||
|
||||
if channel_config["channel_name"] is None:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Channel name is required")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Channel name is required",
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
if slack_channel_config_creation_request.persona_id is not None:
|
||||
@@ -166,7 +171,10 @@ def patch_slack_channel_config(
|
||||
db_session=db_session, slack_channel_config_id=slack_channel_config_id
|
||||
)
|
||||
if existing_slack_channel_config is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Slack channel config not found")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Slack channel config not found",
|
||||
)
|
||||
|
||||
existing_persona_id = existing_slack_channel_config.persona_id
|
||||
if existing_persona_id is not None:
|
||||
|
||||
@@ -13,6 +13,7 @@ from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
@@ -72,8 +73,6 @@ from onyx.db.users import get_page_of_filtered_users
|
||||
from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
@@ -125,7 +124,7 @@ def set_user_role(
|
||||
email=user_role_update_request.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_update:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
current_role = user_to_update.role
|
||||
requested_role = user_role_update_request.new_role
|
||||
@@ -140,9 +139,9 @@ def set_user_role(
|
||||
)
|
||||
|
||||
if user_to_update.id == current_user.id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"An admin cannot demote themselves from admin role!",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="An admin cannot demote themselves from admin role!",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.CURATOR:
|
||||
@@ -387,9 +386,9 @@ def bulk_invite_users(
|
||||
new_invited_emails.append(email_info.normalized)
|
||||
|
||||
except (EmailUndeliverableError, EmailNotValidError) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid email address: {email} - {str(e)}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid email address: {email} - {str(e)}",
|
||||
)
|
||||
|
||||
# Count only new users (not already invited or existing) that need seats
|
||||
@@ -406,9 +405,9 @@ def bulk_invite_users(
|
||||
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
|
||||
current_invited = len(already_invited)
|
||||
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"You have hit your invite limit. "
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You have hit your invite limit. "
|
||||
"Please upgrade for unlimited invites.",
|
||||
)
|
||||
|
||||
@@ -503,16 +502,14 @@ def deactivate_user_api(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if current_user.email == user_email.user_email:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "You cannot deactivate yourself"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You cannot deactivate yourself")
|
||||
|
||||
user_to_deactivate = get_user_by_email(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
|
||||
if not user_to_deactivate:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_deactivate.is_active is False:
|
||||
logger.warning("{} is already deactivated".format(user_to_deactivate.email))
|
||||
@@ -537,15 +534,14 @@ async def delete_user(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_delete:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_delete.is_active is True:
|
||||
logger.warning(
|
||||
"{} must be deactivated before deleting".format(user_to_delete.email)
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"User must be deactivated before deleting",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="User must be deactivated before deleting"
|
||||
)
|
||||
|
||||
# Detach the user from the current session
|
||||
@@ -569,7 +565,7 @@ async def delete_user(
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Error deleting user")
|
||||
raise HTTPException(status_code=500, detail="Error deleting user")
|
||||
|
||||
|
||||
@router.patch("/manage/admin/activate-user", tags=PUBLIC_API_TAGS)
|
||||
@@ -582,7 +578,7 @@ def activate_user_api(
|
||||
email=user_email.user_email, db_session=db_session
|
||||
)
|
||||
if not user_to_activate:
|
||||
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user_to_activate.is_active is True:
|
||||
logger.warning("{} is already activated".format(user_to_activate.email))
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.configs.constants import SLACK_USER_TOKEN_PREFIX
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
SLACK_API_URL = "https://slack.com/api/auth.test"
|
||||
SLACK_CONNECTIONS_OPEN_URL = "https://slack.com/api/apps.connections.open"
|
||||
@@ -13,15 +12,15 @@ def validate_bot_token(bot_token: str) -> bool:
|
||||
response = requests.post(SLACK_API_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid bot token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid bot token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -32,15 +31,15 @@ def validate_app_token(app_token: str) -> bool:
|
||||
response = requests.post(SLACK_CONNECTIONS_OPEN_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid app token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid app token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -55,16 +54,16 @@ def validate_user_token(user_token: str | None) -> None:
|
||||
Returns:
|
||||
None is valid and will return successfully.
|
||||
Raises:
|
||||
OnyxError: If the token is invalid or missing required fields
|
||||
HTTPException: If the token is invalid or missing required fields
|
||||
"""
|
||||
if not user_token:
|
||||
# user_token is optional, so None or empty string is valid
|
||||
return
|
||||
|
||||
if not user_token.startswith(SLACK_USER_TOKEN_PREFIX):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid user token format. User OAuth tokens must start with '{SLACK_USER_TOKEN_PREFIX}'",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user token format. User OAuth tokens must start with '{SLACK_USER_TOKEN_PREFIX}'",
|
||||
)
|
||||
|
||||
# Test the token with Slack API to ensure it's valid
|
||||
@@ -72,13 +71,13 @@ def validate_user_token(user_token: str | None) -> None:
|
||||
response = requests.post(SLACK_API_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Error communicating with Slack API."
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid user token: {data.get('error', 'Unknown error')}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -25,8 +26,6 @@ from onyx.db.web_search import set_active_web_content_provider
|
||||
from onyx.db.web_search import set_active_web_search_provider
|
||||
from onyx.db.web_search import upsert_web_content_provider
|
||||
from onyx.db.web_search import upsert_web_search_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.web_search.models import WebContentProviderTestRequest
|
||||
from onyx.server.manage.web_search.models import WebContentProviderUpsertRequest
|
||||
from onyx.server.manage.web_search.models import WebContentProviderView
|
||||
@@ -87,9 +86,9 @@ def upsert_search_provider_endpoint(
|
||||
and request.id is not None
|
||||
and existing_by_name.id != request.id
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"A search provider named '{request.name}' already exists.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"A search provider named '{request.name}' already exists.",
|
||||
)
|
||||
|
||||
provider = upsert_web_search_provider(
|
||||
@@ -194,16 +193,16 @@ def test_search_provider(
|
||||
request.provider_type, db_session
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if requires_key and not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -213,21 +212,20 @@ def test_search_provider(
|
||||
config=request.config or {},
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Unable to build provider configuration.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Unable to build provider configuration."
|
||||
)
|
||||
|
||||
# Run the API client's test_connection method to ensure the connection is valid.
|
||||
try:
|
||||
return provider.test_connection()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) from e
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
|
||||
@@ -261,9 +259,9 @@ def upsert_content_provider_endpoint(
|
||||
and request.id is not None
|
||||
and existing_by_name.id != request.id
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"A content provider named '{request.name}' already exists.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"A content provider named '{request.name}' already exists.",
|
||||
)
|
||||
|
||||
provider = upsert_web_content_provider(
|
||||
@@ -381,9 +379,9 @@ def test_content_provider(
|
||||
request.provider_type, db_session
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
stored_base_url = (
|
||||
@@ -391,17 +389,17 @@ def test_content_provider(
|
||||
)
|
||||
request_base_url = request.config.base_url
|
||||
if request_base_url != stored_base_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Base URL cannot differ from stored provider when using stored API key",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Base URL cannot differ from stored provider when using stored API key",
|
||||
)
|
||||
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -411,12 +409,11 @@ def test_content_provider(
|
||||
config=request.config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Unable to build provider configuration.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Unable to build provider configuration."
|
||||
)
|
||||
|
||||
# Actually test the API key by making a real content fetch call
|
||||
@@ -428,11 +425,11 @@ def test_content_provider(
|
||||
if not test_results or not any(
|
||||
result.scrape_successful for result in test_results
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key validation failed: content fetch returned no results.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key validation failed: content fetch returned no results.",
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@@ -441,13 +438,13 @@ def test_content_provider(
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid API key: {error_msg}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid API key: {error_msg}",
|
||||
) from e
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"API key validation failed: {error_msg}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.citation_utils import extract_citation_order_from_text
|
||||
@@ -20,7 +22,9 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderResult
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderStart
|
||||
@@ -180,24 +184,37 @@ def create_custom_tool_packets(
|
||||
tab_index: int = 0,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
error: CustomToolErrorInfo | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
tool_id: int | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
tool_id=tool_id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -657,13 +674,55 @@ def translate_assistant_message_to_packets(
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
# Try to parse as structured CustomToolCallSummary JSON
|
||||
custom_data: dict | list | str | int | float | bool | None = (
|
||||
tool_call.tool_call_response
|
||||
)
|
||||
custom_error: CustomToolErrorInfo | None = None
|
||||
custom_response_type = "text"
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_call.tool_call_response)
|
||||
if isinstance(parsed, dict) and "tool_name" in parsed:
|
||||
custom_data = parsed.get("tool_result")
|
||||
custom_response_type = parsed.get(
|
||||
"response_type", "text"
|
||||
)
|
||||
if parsed.get("error"):
|
||||
custom_error = CustomToolErrorInfo(
|
||||
**parsed["error"]
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
ValidationError,
|
||||
):
|
||||
pass
|
||||
|
||||
custom_file_ids: list[str] | None = None
|
||||
if custom_response_type in ("image", "csv") and isinstance(
|
||||
custom_data, dict
|
||||
):
|
||||
custom_file_ids = custom_data.get("file_ids")
|
||||
custom_data = None
|
||||
|
||||
custom_args = {
|
||||
k: v
|
||||
for k, v in (tool_call.tool_call_arguments or {}).items()
|
||||
if k != "requestBody"
|
||||
}
|
||||
turn_tool_packets.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool.display_name or tool.name,
|
||||
response_type="text",
|
||||
response_type=custom_response_type,
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
data=tool_call.tool_call_response,
|
||||
data=custom_data,
|
||||
file_ids=custom_file_ids,
|
||||
error=custom_error,
|
||||
tool_args=custom_args if custom_args else None,
|
||||
tool_id=tool_call.tool_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class StreamingType(Enum):
|
||||
PYTHON_TOOL_START = "python_tool_start"
|
||||
PYTHON_TOOL_DELTA = "python_tool_delta"
|
||||
CUSTOM_TOOL_START = "custom_tool_start"
|
||||
CUSTOM_TOOL_ARGS = "custom_tool_args"
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta"
|
||||
FILE_READER_START = "file_reader_start"
|
||||
FILE_READER_RESULT = "file_reader_result"
|
||||
@@ -41,6 +42,7 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -245,6 +247,20 @@ class CustomToolStart(BaseObj):
|
||||
type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
|
||||
|
||||
class CustomToolArgs(BaseObj):
|
||||
type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value
|
||||
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolErrorInfo(BaseModel):
|
||||
is_auth_error: bool = False
|
||||
status_code: int
|
||||
message: str
|
||||
|
||||
|
||||
# The allowed streamed packets for a custom tool
|
||||
@@ -252,11 +268,22 @@ class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
@@ -366,6 +393,7 @@ PacketObj = Union[
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolArgs,
|
||||
CustomToolDelta,
|
||||
FileReaderStart,
|
||||
FileReaderResult,
|
||||
@@ -379,6 +407,7 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -8,8 +8,6 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
@@ -165,39 +163,6 @@ def create_image_generation_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_custom_tool_packets(
|
||||
tool_name: str,
|
||||
response_type: str,
|
||||
turn_index: int,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetch_docs: list[SavedSearchDoc],
|
||||
urls: list[str],
|
||||
|
||||
@@ -78,6 +78,7 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
|
||||
@@ -3,6 +3,7 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -50,6 +51,7 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
return settings
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user