Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-10 18:22:40 +00:00)

Compare commits: refactor/o...fix/search (166 commits)
| SHA1 |
|---|
| 6ac62b753f |
| dae0868fed |
| d1966e58e3 |
| e05e286467 |
| 1d0687a616 |
| 5cdeb84164 |
| 5b5100a07a |
| 77f58fbad5 |
| cf74afc65e |
| a887bc616c |
| fef1fd093e |
| 8d085a4ccf |
| 28310b9138 |
| f71fab580c |
| 89593b353f |
| 91e24ae63a |
| d2b37724d1 |
| 87f0849330 |
| 2ec7526772 |
| bbd68e2795 |
| e74c36001a |
| fe593a15da |
| 27df690a8d |
| edbe569edd |
| 5118193d16 |
| 63d3efd380 |
| ec978d9a3f |
| d4d98a6cd0 |
| dc40e86dac |
| e495f7a13e |
| 4761e4b132 |
| 6b5ab54b85 |
| 959cf444f8 |
| 2ebccea6d6 |
| 5fe7a474db |
| 9d7dc3da21 |
| 2899be4c5e |
| 64ee7fc23f |
| e07764285d |
| cc2e6ffa8a |
| d3ee5c9b59 |
| dfa0efc093 |
| 9aad4077f1 |
| 29d9ebf7b3 |
| f1df36e306 |
| 1611604269 |
| c2a71091dc |
| cc008699e5 |
| 48802618db |
| 6917953b86 |
| e7cf027f8a |
| 41fb1480bb |
| bdc2bfdcee |
| 8816d52b27 |
| 6590f1d7ba |
| c527f75557 |
| 472d1788a7 |
| 99e95f8205 |
| e618bf8385 |
| f4dcd130ba |
| 910718deaa |
| 1a7ca93b93 |
| a615a920cb |
| 29d8b310b5 |
| d1409ccafa |
| e41bad9103 |
| 661dc831dc |
| 19016dd35a |
| 127b2dcc80 |
| b015a37cea |
| b45277a8b0 |
| 893e8da79a |
| a51f0d7cb2 |
| c826d0469e |
| 0f6ae6f69c |
| d0836e2603 |
| bda03bafca |
| 376adff94a |
| d2d4b89286 |
| dde7a18bb7 |
| 3f004cf02f |
| ae893079c3 |
| 189c07a913 |
| 2b82743bf5 |
| ba2a5a60e1 |
| 5888f9d69f |
| 23b3a0a6ae |
| eced88fa7a |
| f59aaa902d |
| 57349bdbd1 |
| 192639a801 |
| c10ffbb464 |
| 091f41fd1f |
| 45d77be4eb |
| 413fa85134 |
| 108cde4f55 |
| f88ce32bd4 |
| 911f3439ea |
| b02590d2b2 |
| 2d75b4b1f8 |
| 7e3f7d01c2 |
| 9d6ce26ea3 |
| 41713d42a2 |
| 8afc283410 |
| b5c873077e |
| 20a4dd32eb |
| fde0d44bc1 |
| 8fd91b6e83 |
| 8247fdd45b |
| 8c5859ba4d |
| 62ef6f59bb |
| 7eabfa125c |
| ee18114739 |
| f7630f5648 |
| e0d91b9ea7 |
| 2c0a4a60a5 |
| 3a7d4dad56 |
| c5c236d098 |
| b18baff4d0 |
| eb3e15c195 |
| 47d9a9e1ac |
| aca466b35d |
| 5176fd7386 |
| 92538084e9 |
| 2d996e05a4 |
| b2956f795b |
| b272085543 |
| 8193aa4fd0 |
| 52db41a00b |
| f1cf3c4589 |
| 5322aeed90 |
| 5da8870fd2 |
| 57d3ab3b40 |
| 649c7fe8b9 |
| e5e2bc6149 |
| b148065e1d |
| 367808951c |
| 0f74da3302 |
| 96f7cbd25a |
| c627cea17d |
| a8cdc3965d |
| 60891b2f44 |
| d2f35e1fae |
| 7a7350f387 |
| 8ef504acd5 |
| 0dbabfe445 |
| 50575d0f6b |
| 9862fbd4a6 |
| 003d94546a |
| 01d3473974 |
| 19c7809a43 |
| 98e6346152 |
| c63fdf1c13 |
| 49b509a0a7 |
| 2b1f1fe311 |
| 3e67ea9df7 |
| 98e3602dd6 |
| 4fded5b0a1 |
| 328c305d26 |
| f902727215 |
| 69c8aa08b3 |
| c98aa486e4 |
| 03553114c5 |
| 6532c94230 |
| 1b32a7d94e |
| 5fd0fe192b |

.cursor/skills/onyx-cli/SKILL.md (new file, 186 lines)

@@ -0,0 +1,186 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---

# Onyx CLI — Agent Tool

Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` command provides non-interactive subcommands to query the Onyx knowledge base and list available agents.

## Prerequisites

### 1. Check if installed

```bash
which onyx-cli
```

### 2. Install (if needed)

**Primary — pip:**

```bash
pip install onyx-cli
```

**From source (Go):**

```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```

### 3. Check if configured

```bash
onyx-cli validate-config
```

This checks that the config file exists and an API key is present, then tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.

If unconfigured, you have two options:

**Option A — Interactive setup (requires user input):**

```bash
onyx-cli configure
```

This prompts for the Onyx server URL and API key, tests the connection, and saves the config.

**Option B — Environment variables (non-interactive, preferred for agents):**

```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```

Environment variables override the config file. If these are set, no config file is needed.
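The variables can also be supplied inline for a one-off call, with no config file involved. A minimal sketch (server URL, key, and question are placeholder values):

```bash
# One-off query with inline environment variables (placeholder values)
ONYX_SERVER_URL="https://your-onyx-server.com" \
ONYX_API_KEY="your-api-key" \
onyx-cli ask "Where is the onboarding guide?"
```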
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:

- Run `onyx-cli configure` interactively, or
- Set the `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables

## Commands

### Validate configuration

```bash
onyx-cli validate-config
```

Checks that the config file exists and an API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
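Because failures surface through the exit code, scripts can gate queries on a successful check. A minimal sketch (the query text is illustrative):

```bash
# Only query once the CLI is confirmed to be configured and reachable
if onyx-cli validate-config; then
  onyx-cli ask "What is our incident response process?"
else
  echo "onyx-cli is not configured; run 'onyx-cli configure' or set ONYX_SERVER_URL and ONYX_API_KEY" >&2
  exit 1
fi
```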
### List available agents

```bash
onyx-cli agents
```

Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:

```bash
onyx-cli agents --json
```

Use agent IDs with `ask --agent-id` to query a specific agent.

### Basic query (plain text output)

```bash
onyx-cli ask "What is our company's PTO policy?"
```

Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.

### JSON output (structured events)

```bash
onyx-cli ask --json "What authentication methods do we support?"
```

Outputs JSON-encoded parsed stream events, one object per line. Key event types include message deltas, stop, errors, search start, and citation payloads.

Each line is a JSON object with this envelope:

```json
{"type": "<event_type>", "event": { ... }}
```

| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |

`citation_info` event shape:

```json
{
  "type": "citation_info",
  "event": {
    "citation_number": 1,
    "document_id": "abc123def456",
    "placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
  }
}
```

`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
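One way to consume the stream is with `jq`, selecting events by `type`. A minimal sketch, assuming the envelope and field names shown above (`.event.content` on deltas, `.event.document_id` on citations):

```bash
# Reassemble the full answer by concatenating message_delta content tokens
onyx-cli ask --json "What authentication methods do we support?" \
  | jq -rj 'select(.type == "message_delta") | .event.content'

# List the document IDs of all cited sources
onyx-cli ask --json "What authentication methods do we support?" \
  | jq -r 'select(.type == "citation_info") | .event.document_id'
```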
### Specify an agent

```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```

Uses a specific Onyx agent/persona instead of the default.

### All flags

| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |

## Statelessness

Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need a multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
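When a little carry-over is enough, a script can thread context by hand, embedding an earlier answer in the next prompt. A rough sketch, not a replacement for real multi-turn memory:

```bash
# Manual context threading across two stateless ask calls
summary=$(onyx-cli ask "Summarize our Q4 roadmap")
onyx-cli ask "Given this summary: ${summary} -- what are the biggest delivery risks?"
```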
## When to Use

Use `onyx-cli ask` when:

- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources

Do NOT use when:

- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data

## Examples

```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"

# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"

# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"

# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

.github/workflows/deployment.yml (vendored, 98 lines changed)
@@ -151,7 +151,7 @@ jobs:
          fetch-depth: 0

      - name: Setup uv
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          version: "0.9.9"
          # NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
@@ -182,9 +182,53 @@ jobs:
          title: "🚨 Version Tag Check Failed"
          ref-name: ${{ github.ref_name }}

-  build-desktop:
+  # Create GitHub release first, before desktop builds start.
+  # This ensures all desktop matrix jobs upload to the same release instead of
+  # racing to create duplicate releases.
+  create-release:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-desktop == 'true'
+    runs-on: ubuntu-slim
+    timeout-minutes: 10
+    permissions:
+      contents: write
+    outputs:
+      release-id: ${{ steps.create-release.outputs.id }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Determine release tag
+        id: release-tag
+        env:
+          IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
+          SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
+        run: |
+          if [ "${IS_TEST_RUN}" == "true" ]; then
+            echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
+          else
+            echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
+          fi
+
+      - name: Create GitHub Release
+        id: create-release
+        uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
+        with:
+          tag_name: ${{ steps.release-tag.outputs.tag }}
+          name: ${{ steps.release-tag.outputs.tag }}
+          body: "See the assets to download this version and install."
+          draft: true
+          prerelease: false
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+  build-desktop:
+    needs:
+      - determine-builds
+      - create-release
+    if: needs.determine-builds.outputs.build-desktop == 'true'
    permissions:
      id-token: write
      contents: write
@@ -208,12 +252,12 @@ jobs:
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
        with:
-          # NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
+          # NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
          persist-credentials: true # zizmor: ignore[artipacked]

      - name: Configure AWS credentials
        if: startsWith(matrix.platform, 'macos-')
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -353,11 +397,9 @@ jobs:
          APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
          APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
        with:
-          tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
-          releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
-          releaseBody: "See the assets to download this version and install."
-          releaseDraft: true
-          prerelease: false
+          # Use the release created by the create-release job to avoid race conditions
+          # when multiple matrix jobs try to create/update the same release simultaneously
+          releaseId: ${{ needs.create-release.outputs.release-id }}
          assetNamePattern: "[name]_[arch][ext]"
          args: ${{ matrix.args }}
@@ -384,7 +426,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -458,7 +500,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -527,7 +569,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -597,7 +639,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -679,7 +721,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -756,7 +798,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -823,7 +865,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -896,7 +938,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -964,7 +1006,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1034,7 +1076,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1107,7 +1149,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1176,7 +1218,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1246,7 +1288,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1326,7 +1368,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1400,7 +1442,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1465,7 +1507,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1520,7 +1562,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1580,7 +1622,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -1637,7 +1679,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -15,6 +15,8 @@ permissions:
jobs:
  provider-chat-test:
    uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
+    secrets:
+      AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
    permissions:
      contents: read
      id-token: write
@@ -6,11 +6,13 @@ on:
      - main

permissions:
-  contents: write
-  pull-requests: write
+  contents: read

jobs:
  cherry-pick-to-latest-release:
+    permissions:
+      contents: write
+      pull-requests: write
    outputs:
      should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
      pr_number: ${{ steps.gate.outputs.pr_number }}
@@ -68,7 +70,7 @@ jobs:

      - name: Install the latest version of uv
        if: steps.gate.outputs.should_cherrypick == 'true'
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"

.github/workflows/pr-desktop-build.yml (vendored, 2 lines changed)
@@ -57,7 +57,7 @@ jobs:
          cache-dependency-path: ./desktop/package-lock.json

      - name: Setup Rust
-        uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
        with:
          toolchain: stable
          targets: ${{ matrix.target }}

.github/workflows/pr-golang-tests.yml (vendored, new file, 56 lines)
@@ -0,0 +1,56 @@
name: Golang Tests
concurrency:
  group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"
  push:
    tags:
      - "v*.*.*"

permissions: {}

env:
  GO_VERSION: "1.26"

jobs:
  detect-modules:
    runs-on: ubuntu-latest
    timeout-minutes: 10
    outputs:
      modules: ${{ steps.set-modules.outputs.modules }}
    steps:
      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
        with:
          persist-credentials: false
      - id: set-modules
        run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"

  golang:
    needs: detect-modules
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
    steps:
      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
        with:
          go-version: ${{ env.GO_VERSION }}
          cache-dependency-path: "**/go.sum"

      - run: go mod tidy
        working-directory: ${{ matrix.modules }}
      - run: git diff --exit-code go.mod go.sum
        working-directory: ${{ matrix.modules }}

      - run: go test ./...
        working-directory: ${{ matrix.modules }}

.github/workflows/pr-helm-chart-testing.yml (vendored, 2 lines changed)
@@ -71,7 +71,7 @@ jobs:

      - name: Create kind cluster
        if: steps.list-changed.outputs.changed == 'true'
-        uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
+        uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0

      - name: Pre-install cluster status check
        if: steps.list-changed.outputs.changed == 'true'

.github/workflows/pr-integration-tests.yml (vendored, 56 lines changed)
@@ -316,6 +316,7 @@ jobs:
          # Base config shared by both editions
          cat <<EOF > deployment/docker_compose/.env
          COMPOSE_PROFILES=s3-filestore
+          OPENSEARCH_FOR_ONYX_ENABLED=false
          AUTH_TYPE=basic
          POSTGRES_POOL_PRE_PING=true
          POSTGRES_USE_NULL_POOL=true
@@ -335,7 +336,6 @@ jobs:
          # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
          LICENSE_ENFORCEMENT_ENABLED=false
          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
-          USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
          EOF
          fi
@@ -419,6 +419,7 @@ jobs:
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
+            -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
@@ -471,13 +472,13 @@ jobs:
          path: ${{ github.workspace }}/docker-compose.log
      # ------------------------------------------------------------

-  no-vectordb-tests:
+  onyx-lite-tests:
    needs: [build-backend-image, build-integration-image]
    runs-on:
      [
        runs-on,
        runner=4cpu-linux-arm64,
-        "run-id=${{ github.run_id }}-no-vectordb-tests",
+        "run-id=${{ github.run_id }}-onyx-lite-tests",
        "extras=ecr-cache",
      ]
    timeout-minutes: 45
@@ -495,13 +496,12 @@ jobs:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

-      - name: Create .env file for no-vectordb Docker Compose
+      - name: Create .env file for Onyx Lite Docker Compose
        env:
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          cat <<EOF > deployment/docker_compose/.env
          COMPOSE_PROFILES=s3-filestore
-          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
          LICENSE_ENFORCEMENT_ENABLED=false
          AUTH_TYPE=basic
@@ -509,28 +509,23 @@ jobs:
          POSTGRES_USE_NULL_POOL=true
          REQUIRE_EMAIL_VERIFICATION=false
          DISABLE_TELEMETRY=true
          DISABLE_VECTOR_DB=true
          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
          INTEGRATION_TESTS_MODE=true
          USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
          EOF

-      # Start only the services needed for no-vectordb mode (no Vespa, no model servers)
-      - name: Start Docker containers (no-vectordb)
+      # Start only the services needed for Onyx Lite (Postgres + API server)
+      - name: Start Docker containers (onyx-lite)
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
            relational_db \
            cache \
            minio \
            api_server \
            background \
            -d
-        id: start_docker_no_vectordb
+        id: start_docker_onyx_lite

      - name: Wait for services to be ready
        run: |
-          echo "Starting wait-for-service script (no-vectordb)..."
+          echo "Starting wait-for-service script (onyx-lite)..."
          start_time=$(date +%s)
          timeout=300
          while true; do
@@ -552,14 +547,14 @@ jobs:
            sleep 5
          done

-      - name: Run No-VectorDB Integration Tests
+      - name: Run Onyx Lite Integration Tests
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:
          timeout_minutes: 20
          max_attempts: 3
          retry_wait_seconds: 10
          command: |
-            echo "Running no-vectordb integration tests..."
+            echo "Running onyx-lite integration tests..."
            docker run --rm --network onyx_default \
              --name test-runner \
              -e POSTGRES_HOST=relational_db \
@@ -570,39 +565,38 @@ jobs:
              -e DB_READONLY_PASSWORD=password \
              -e POSTGRES_POOL_PRE_PING=true \
              -e POSTGRES_USE_NULL_POOL=true \
              -e REDIS_HOST=cache \
              -e API_SERVER_HOST=api_server \
              -e OPENAI_API_KEY=${OPENAI_API_KEY} \
              -e TEST_WEB_HOSTNAME=test-runner \
              ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
              /app/tests/integration/tests/no_vectordb

-      - name: Dump API server logs (no-vectordb)
+      - name: Dump API server logs (onyx-lite)
        if: always()
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
-            logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
+            logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true

-      - name: Dump all-container logs (no-vectordb)
+      - name: Dump all-container logs (onyx-lite)
        if: always()
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
-            logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
+            logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true

-      - name: Upload logs (no-vectordb)
+      - name: Upload logs (onyx-lite)
        if: always()
        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
        with:
-          name: docker-all-logs-no-vectordb
-          path: ${{ github.workspace }}/docker-compose-no-vectordb.log
+          name: docker-all-logs-onyx-lite
+          path: ${{ github.workspace }}/docker-compose-onyx-lite.log

-      - name: Stop Docker containers (no-vectordb)
+      - name: Stop Docker containers (onyx-lite)
        if: always()
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v

  multitenant-tests:
    needs:
@@ -645,6 +639,7 @@ jobs:
            ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
            ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
            DEV_MODE=true \
+            OPENSEARCH_FOR_ONYX_ENABLED=false \
            docker compose -f docker-compose.multitenant-dev.yml up \
              relational_db \
              index \
@@ -699,6 +694,7 @@ jobs:
              -e POSTGRES_DB=postgres \
              -e POSTGRES_USE_NULL_POOL=true \
              -e VESPA_HOST=index \
+              -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
              -e REDIS_HOST=cache \
              -e API_SERVER_HOST=api_server \
              -e OPENAI_API_KEY=${OPENAI_API_KEY} \
@@ -744,7 +740,7 @@ jobs:
    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 45
-    needs: [integration-tests, no-vectordb-tests, multitenant-tests]
+    needs: [integration-tests, onyx-lite-tests, multitenant-tests]
    if: ${{ always() }}
    steps:
      - name: Check job status

.github/workflows/pr-jest-tests.yml (vendored, 2 lines changed)
@@ -31,7 +31,7 @@ jobs:
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
        with:
          node-version: 22
-          cache: "npm"
+          cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
          cache-dependency-path: ./web/package-lock.json

      - name: Install node dependencies

.github/workflows/pr-playwright-tests.yml (vendored, 114 lines changed)
@@ -268,10 +268,11 @@ jobs:
          persist-credentials: false

      - name: Setup node
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
        with:
          node-version: 22
-          cache: "npm"
+          cache: "npm" # zizmor: ignore[cache-poisoning]
          cache-dependency-path: ./web/package-lock.json

      - name: Install node dependencies
@@ -279,6 +280,7 @@ jobs:
        run: npm ci

      - name: Cache playwright cache
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        with:
          path: ~/.cache/ms-playwright
@@ -459,14 +461,14 @@ jobs:
      # --- Visual Regression Diff ---
      - name: Configure AWS credentials
        if: always()
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Install the latest version of uv
        if: always()
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
@@ -590,6 +592,108 @@ jobs:
          name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
          path: ${{ github.workspace }}/docker-compose.log

+  playwright-tests-lite:
+    needs: [build-web-image, build-backend-image]
+    name: Playwright Tests (lite)
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - "run-id=${{ github.run_id }}-playwright-tests-lite"
+      - "extras=ecr-cache"
+    timeout-minutes: 30
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Setup node
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
+        with:
+          node-version: 22
+          cache: "npm" # zizmor: ignore[cache-poisoning]
+          cache-dependency-path: ./web/package-lock.json
+
+      - name: Install node dependencies
+        working-directory: ./web
+        run: npm ci
+
+      - name: Cache playwright cache
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
+        with:
+          path: ~/.cache/ms-playwright
+          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
+          restore-keys: |
+            ${{ runner.os }}-playwright-npm-
+
+      - name: Install playwright browsers
+        working-directory: ./web
+        run: npx playwright install --with-deps
+
+      - name: Create .env file for Docker Compose
+        env:
+          OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
+        run: |
+          cat <<EOF > deployment/docker_compose/.env
+          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
+          LICENSE_ENFORCEMENT_ENABLED=false
+          AUTH_TYPE=basic
+          INTEGRATION_TESTS_MODE=true
+          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
+          MOCK_LLM_RESPONSE=true
+          REQUIRE_EMAIL_VERIFICATION=false
+          DISABLE_TELEMETRY=true
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
+          ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
+          EOF
+
+      # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
+      # https://docs.docker.com/docker-hub/usage/
+      - name: Login to Docker Hub
+        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Start Docker containers (lite)
+        run: |
+          cd deployment/docker_compose
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
+        id: start_docker
+
+      - name: Run Playwright tests (lite)
+        working-directory: ./web
+        run: npx playwright test --project lite
+
+      - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
+        if: always()
+        with:
+          name: playwright-test-results-lite-${{ github.run_id }}
+          path: ./web/output/playwright/
+          retention-days: 30
+
+      - name: Save Docker logs
+        if: success() || failure()
+        env:
+          WORKSPACE: ${{ github.workspace }}
+        run: |
+          cd deployment/docker_compose
+          docker compose logs > docker-compose.log
+          mv docker-compose.log ${WORKSPACE}/docker-compose.log
+
+      - name: Upload logs
+        if: success() || failure()
+        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
+        with:
+          name: docker-logs-lite-${{ github.run_id }}
+          path: ${{ github.workspace }}/docker-compose.log
+
  # Post a single combined visual regression comment after all matrix jobs finish
  visual-regression-comment:
    needs: [playwright-tests]
@@ -603,7 +707,7 @@ jobs:
      pull-requests: write
    steps:
      - name: Download visual diff summaries
-        uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
+        uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
        with:
          pattern: screenshot-diff-summary-*
          path: summaries/
@@ -686,7 +790,7 @@ jobs:
    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 45
-    needs: [playwright-tests]
+    needs: [playwright-tests, playwright-tests-lite]
    if: ${{ always() }}
    steps:
      - name: Check job status

.github/workflows/pr-quality-checks.yml (vendored, 6 lines changed)
@@ -28,7 +28,7 @@ jobs:
        with:
          python-version: "3.11"
      - name: Setup Terraform
-        uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
+        uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
      - name: Setup node
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
        with: # zizmor: ignore[cache-poisoning]
@@ -38,9 +38,9 @@ jobs:
      - name: Install node dependencies
        working-directory: ./web
        run: npm ci
-      - uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
+      - uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
        with:
-          prek-version: '0.2.21'
+          prek-version: '0.3.4'
          extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
      - name: Check Actions
        uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1

.github/workflows/release-cli.yml (vendored, new file, 214 lines)
@@ -0,0 +1,214 @@
name: Release CLI

on:
  push:
    tags:
      - "cli/v*.*.*"

jobs:
  pypi:
    runs-on: ubuntu-latest
    environment:
      name: release-cli
    permissions:
      id-token: write
    timeout-minutes: 10
    strategy:
      matrix:
        os-arch:
          - { goos: "linux", goarch: "amd64" }
          - { goos: "linux", goarch: "arm64" }
          - { goos: "windows", goarch: "amd64" }
          - { goos: "windows", goarch: "arm64" }
          - { goos: "darwin", goarch: "amd64" }
          - { goos: "darwin", goarch: "arm64" }
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
      - run: |
          GOOS="${{ matrix.os-arch.goos }}" \
          GOARCH="${{ matrix.os-arch.goarch }}" \
          uv build --wheel
        working-directory: cli
      - run: uv publish
        working-directory: cli

  docker-amd64:
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-cli-amd64
      - extras=ecr-cache
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 30
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
        with:
          context: ./cli
          file: ./cli/Dockerfile
          platforms: linux/amd64
          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: type=inline
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

  docker-arm64:
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-cli-arm64
      - extras=ecr-cache
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 30
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
        with:
          context: ./cli
          file: ./cli/Dockerfile
          platforms: linux/arm64
          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: type=inline
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

  merge-docker:
    needs:
      - docker-amd64
      - docker-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-cli-merge
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 10
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Create and push manifest
        env:
          AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
          TAG: ${{ github.ref_name }}
        run: |
          SANITIZED_TAG="${TAG#cli/}"
          IMAGES=(
            "${REGISTRY_IMAGE}@${AMD64_DIGEST}"
            "${REGISTRY_IMAGE}@${ARM64_DIGEST}"
          )

          if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            docker buildx imagetools create \
              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
              -t "${REGISTRY_IMAGE}:latest" \
              "${IMAGES[@]}"
          else
            docker buildx imagetools create \
              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
              "${IMAGES[@]}"
          fi

.github/workflows/release-devtools.yml (vendored, 4 lines changed)
@@ -22,13 +22,11 @@ jobs:
          - { goos: "windows", goarch: "arm64" }
          - { goos: "darwin", goarch: "amd64" }
          - { goos: "darwin", goarch: "arm64" }
-          - { goos: "", goarch: "" }
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
-          fetch-depth: 0
-      - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+      - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
@@ -48,6 +48,10 @@ on:
        required: false
        default: true
        type: boolean
+  secrets:
+    AWS_OIDC_ROLE_ARN:
+      description: "AWS role ARN for OIDC auth"
+      required: true

permissions:
  contents: read
@@ -73,7 +77,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -116,7 +120,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -158,7 +162,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -264,7 +268,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

.github/workflows/sandbox-deployment.yml (vendored, 6 lines changed)
@@ -110,7 +110,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -180,7 +180,7 @@ jobs:
          persist-credentials: false

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
@@ -244,7 +244,7 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

.github/workflows/storybook-deploy.yml (vendored, new file, 69 lines)
@@ -0,0 +1,69 @@
name: Storybook Deploy
env:
  VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
  VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
  VERCEL_CLI: vercel@50.14.1
  VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}

concurrency:
  group: storybook-deploy-production
  cancel-in-progress: true

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - "web/lib/opal/**"
      - "web/src/refresh-components/**"
      - "web/.storybook/**"
      - "web/package.json"
      - "web/package-lock.json"
permissions:
  contents: read
jobs:
  Deploy-Storybook:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Setup node
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
        with:
          node-version: 22
          cache: "npm"
          cache-dependency-path: ./web/package-lock.json

      - name: Install dependencies
        working-directory: web
        run: npm ci

      - name: Build Storybook
        working-directory: web
        run: npm run storybook:build

      - name: Deploy to Vercel (Production)
        working-directory: web
        run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes

  notify-slack-on-failure:
    needs: Deploy-Storybook
    if: always() && needs.Deploy-Storybook.result == 'failure'
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
        with:
          persist-credentials: false
          sparse-checkout: .github/actions/slack-notify

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
          failed-jobs: "• Deploy-Storybook"
          title: "🚨 Storybook Deploy Failed"
.github/workflows/zizmor.yml (vendored, 2 changes)

@@ -24,7 +24,7 @@ jobs:
       persist-credentials: false

     - name: Install the latest version of uv
-      uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+      uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
       with:
         enable-cache: false
         version: "0.9.9"
@@ -119,10 +119,11 @@ repos:
   ]

 - repo: https://github.com/golangci/golangci-lint
-  rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
+  rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
   hooks:
     - id: golangci-lint
-      entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
+      language_version: "1.26.0"
+      entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"

 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
.vscode/launch.json (vendored, 58 changes)

@@ -40,19 +40,7 @@
       }
     },
     {
-      "name": "Celery (lightweight mode)",
-      "configurations": [
-        "Celery primary",
-        "Celery background",
-        "Celery beat"
-      ],
-      "presentation": {
-        "group": "1"
-      },
-      "stopAll": true
-    },
-    {
-      "name": "Celery (standard mode)",
+      "name": "Celery",
       "configurations": [
         "Celery primary",
         "Celery light",

@@ -253,35 +241,6 @@
       },
       "consoleTitle": "Celery light Console"
     },
-    {
-      "name": "Celery background",
-      "type": "debugpy",
-      "request": "launch",
-      "module": "celery",
-      "cwd": "${workspaceFolder}/backend",
-      "envFile": "${workspaceFolder}/.vscode/.env",
-      "env": {
-        "LOG_LEVEL": "INFO",
-        "PYTHONUNBUFFERED": "1",
-        "PYTHONPATH": "."
-      },
-      "args": [
-        "-A",
-        "onyx.background.celery.versioned_apps.background",
-        "worker",
-        "--pool=threads",
-        "--concurrency=20",
-        "--prefetch-multiplier=4",
-        "--loglevel=INFO",
-        "--hostname=background@%n",
-        "-Q",
-        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
-      ],
-      "presentation": {
-        "group": "2"
-      },
-      "consoleTitle": "Celery background Console"
-    },
     {
       "name": "Celery heavy",
       "type": "debugpy",

@@ -526,21 +485,6 @@
         "group": "3"
       }
     },
-    {
-      "name": "Clear and Restart OpenSearch Container",
-      // Generic debugger type, required arg but has no bearing on bash.
-      "type": "node",
-      "request": "launch",
-      "runtimeExecutable": "bash",
-      "runtimeArgs": [
-        "${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
-      ],
-      "cwd": "${workspaceFolder}",
-      "console": "integratedTerminal",
-      "presentation": {
-        "group": "3"
-      }
-    },
     {
       "name": "Eval CLI",
       "type": "debugpy",
AGENTS.md (76 changes)

@@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized workers
 - Monitoring tasks (every 5 minutes)
 - Cleanup tasks (hourly)

-#### Worker Deployment Modes
-
-Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
-
-**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
-
-- Runs a single consolidated `background` worker that handles all background tasks:
-  - Light worker tasks (Vespa operations, permissions sync, deletion)
-  - Document processing (indexing pipeline)
-  - Document fetching (connector data retrieval)
-  - Pruning operations (from `heavy` worker)
-  - Knowledge graph processing (from `kg_processing` worker)
-  - Monitoring tasks (from `monitoring` worker)
-  - User file processing (from `user_file_processing` worker)
-- Lower resource footprint (fewer worker processes)
-- Suitable for smaller deployments or development environments
-- Default concurrency: 20 threads (increased to handle combined workload)
-
-**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
-
-- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
-- Better isolation and scalability
-- Can scale individual workers independently based on workload
-- Suitable for production deployments with higher load
-
-The deployment mode affects:
-
-- **Backend**: Worker processes spawned by supervisord or dev scripts
-- **Helm**: Which Kubernetes deployments are created
-- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
-
 #### Key Features

 - **Thread-based Workers**: All workers use thread pools (not processes) for stability

@@ -135,6 +104,10 @@ The deployment mode affects:
 - Always use `@shared_task` rather than `@celery_app`
 - Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
+- Never enqueue a task without an expiration. Always supply `expires=` when
+  sending tasks, either from the beat schedule or directly from another task. It
+  should never be acceptable to submit code which enqueues tasks without an
+  expiration, as doing so can lead to unbounded task queue growth.
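As a hedged illustration of the expiration rule added above (not part of the diff; the task name and values here are hypothetical, while `apply_async(expires=...)` and the beat `options` dict are standard Celery API):

```python
from celery import shared_task


@shared_task
def prune_stale_entries() -> None:
    """Hypothetical task, used only to illustrate the expires= rule."""
    ...


# Direct enqueue: the task is discarded if it is still sitting in the
# queue after 10 minutes, preventing unbounded backlog growth.
prune_stale_entries.apply_async(expires=600)

# Beat schedule entry: the same rule applies to periodic tasks.
beat_schedule_entry = {
    "task": "prune_stale_entries",
    "schedule": 3600.0,  # hourly
    "options": {"expires": 600},
}
```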
**Defining APIs**:
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the

@@ -571,6 +544,8 @@ To run them:
 npx playwright test <TEST_NAME>
 ```

+For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
+
 ## Logs

 When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
@@ -617,6 +592,45 @@ Keep it high level. You can reference certain files or functions though.

 Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.

+## Error Handling
+
+**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
+Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
+
+A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
+`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
+handling consistent across the entire backend.
+
+```python
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
+
+# ✅ Good
+raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
+
+# ✅ Good — no extra message needed
+raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
+
+# ✅ Good — upstream service with dynamic status code
+raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
+
+# ❌ Bad — using HTTPException directly
+raise HTTPException(status_code=404, detail="Session not found")
+
+# ❌ Bad — starlette constant
+raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
+```
+
+Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
+category is needed, add it there first — do not invent ad-hoc codes.
+
+**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
+status code is dynamic (comes from the upstream response), use `status_code_override`:
+
+```python
+raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
+```
## Best Practices

In addition to the other content in this file, best practices for contributing
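A rough sketch of the global handler the Error Handling section describes. This is an assumption-based illustration, not the actual Onyx implementation; the attribute names on `OnyxError` (`status_code`, `error_code`, `detail`) are inferred only from how the diffs in this changeset use the class:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from onyx.error_handling.exceptions import OnyxError  # real module per the diff

app = FastAPI()


@app.exception_handler(OnyxError)
async def handle_onyx_error(request: Request, exc: OnyxError) -> JSONResponse:
    # Converts OnyxError into the standard {"error_code", "detail"} JSON
    # shape. status_code reflects the error code's default HTTP status,
    # or status_code_override when one was supplied (assumed attributes).
    return JSONResponse(
        status_code=exc.status_code,
        content={"error_code": exc.error_code.value, "detail": exc.detail},
    )
```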
@@ -46,7 +46,9 @@ RUN apt-get update && \
         pkg-config \
         gcc \
         nano \
-        vim && \
+        vim \
+        libjemalloc2 \
+        && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean

@@ -141,6 +143,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
 COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
 COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
 COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
+COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
 RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

 # Run Craft template setup at build time when ENABLE_CRAFT=true

@@ -164,6 +167,13 @@ ENV PYTHONPATH=/app
 ARG ONYX_VERSION=0.0.0-dev
 ENV ONYX_VERSION=${ONYX_VERSION}

+# Use jemalloc instead of glibc malloc to reduce memory fragmentation
+# in long-running Python processes (API server, Celery workers).
+# The soname is architecture-independent; the dynamic linker resolves
+# the correct path from standard library directories.
+# Placed after all RUN steps so build-time processes are unaffected.
+ENV LD_PRELOAD=libjemalloc.so.2
+
 # Default command which does nothing
 # This container is used by api server and background which specify their own CMD
 CMD ["tail", "-f", "/dev/null"]
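A quick, hedged way to confirm the `LD_PRELOAD` above actually took effect inside a running container (assumes a Linux /proc filesystem; this check is not part of the diff):

```python
# Check whether libjemalloc is mapped into the current Python process.
# On Linux, LD_PRELOAD'ed libraries appear in /proc/self/maps.
def jemalloc_loaded() -> bool:
    with open("/proc/self/maps") as maps:
        return any("libjemalloc" in line for line in maps)


if __name__ == "__main__":
    print(f"jemalloc preloaded: {jemalloc_loaded()}")
```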
@@ -0,0 +1,51 @@ (new Alembic migration file)
+"""add hierarchy_node_by_connector_credential_pair table
+
+Revision ID: b5c4d7e8f9a1
+Revises: a3b8d9e2f1c4
+Create Date: 2026-03-04
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+revision = "b5c4d7e8f9a1"
+down_revision = "a3b8d9e2f1c4"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "hierarchy_node_by_connector_credential_pair",
+        sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
+        sa.Column("connector_id", sa.Integer(), nullable=False),
+        sa.Column("credential_id", sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["hierarchy_node_id"],
+            ["hierarchy_node.id"],
+            ondelete="CASCADE",
+        ),
+        sa.ForeignKeyConstraint(
+            ["connector_id", "credential_id"],
+            [
+                "connector_credential_pair.connector_id",
+                "connector_credential_pair.credential_id",
+            ],
+            ondelete="CASCADE",
+        ),
+        sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
+    )
+    op.create_index(
+        "ix_hierarchy_node_cc_pair_connector_credential",
+        "hierarchy_node_by_connector_credential_pair",
+        ["connector_id", "credential_id"],
+    )
+
+
+def downgrade() -> None:
+    op.drop_index(
+        "ix_hierarchy_node_cc_pair_connector_credential",
+        table_name="hierarchy_node_by_connector_credential_pair",
+    )
+    op.drop_table("hierarchy_node_by_connector_credential_pair")
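For context, a hedged sketch of the SQLAlchemy association model that would mirror the migration above. The class name and module are hypothetical; only the table shape (composite primary key, cascading composite foreign key) comes from the migration itself:

```python
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class HierarchyNodeByConnectorCredentialPair(Base):
    """Hypothetical ORM mapping mirroring the migration's table definition."""

    __tablename__ = "hierarchy_node_by_connector_credential_pair"

    hierarchy_node_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("hierarchy_node.id", ondelete="CASCADE"),
        primary_key=True,
    )
    connector_id = sa.Column(sa.Integer, primary_key=True)
    credential_id = sa.Column(sa.Integer, primary_key=True)

    __table_args__ = (
        # Composite FK onto connector_credential_pair, matching the
        # migration's ForeignKeyConstraint with ON DELETE CASCADE.
        sa.ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
    )
```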
@@ -11,7 +11,6 @@ from sqlalchemy import text
 from alembic import op
 from onyx.configs.app_configs import DB_READONLY_PASSWORD
 from onyx.configs.app_configs import DB_READONLY_USER
-from shared_configs.configs import MULTI_TENANT


 # revision identifiers, used by Alembic.

@@ -22,59 +21,52 @@ depends_on = None


 def upgrade() -> None:
-    if MULTI_TENANT:
-        # Enable pg_trgm extension if not already enabled
-        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
-
-    # Create the read-only db user if it does not already exist.
-    if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
-        raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
+    # Enable pg_trgm extension if not already enabled
+    op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
+
+    # Create read-only db user here only in multi-tenant mode. For single-tenant mode,
+    # the user is created in the standard migration.
+    if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
+        raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

     op.execute(
         text(
             f"""
    DO $$
    BEGIN
        -- Check if the read-only user already exists
        IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
            -- Create the read-only user with the specified password
            EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
            -- First revoke all privileges to ensure a clean slate
            EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
            -- Grant only the CONNECT privilege to allow the user to connect to the database
            -- but not perform any operations without additional specific grants
            EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
        END IF;
    END
    $$;
    """
         )
     )


 def downgrade() -> None:
-    if MULTI_TENANT:
-        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
-        # the user is dropped in the alembic_tenants migration.
-
-        op.execute(
-            text(
-                f"""
-        DO $$
-        BEGIN
-            IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-                -- First revoke all privileges from the database
-                EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-                -- Then revoke all privileges from the public schema
-                EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
-                -- Then drop the user
-                EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
-            END IF;
-        END
-        $$;
-        """
-            )
-        )
-        op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
+    op.execute(
+        text(
+            f"""
+    DO $$
+    BEGIN
+        IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
+            -- First revoke all privileges from the database
+            EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
+            -- Then revoke all privileges from the public schema
+            EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
+            -- Then drop the user
+            EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
+        END IF;
+    END
+    $$;
+    """
+        )
+    )
+    op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
@@ -9,12 +9,15 @@ from onyx.access.access import (
     _get_access_for_documents as get_access_for_documents_without_groups,
 )
 from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
+from onyx.access.access import collect_user_file_access
 from onyx.access.models import DocumentAccess
 from onyx.access.utils import prefix_external_group
 from onyx.access.utils import prefix_user_group
 from onyx.db.document import get_document_sources
 from onyx.db.document import get_documents_by_ids
 from onyx.db.models import User
+from onyx.db.models import UserFile
+from onyx.db.user_file import fetch_user_files_with_access_relationships
 from onyx.utils.logger import setup_logger


@@ -116,6 +119,68 @@ def _get_access_for_documents(
     return access_map


+def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
+    """Extract user-group names from the already-loaded Persona.groups
+    relationships on a UserFile (skipping deleted personas)."""
+    groups: set[str] = set()
+    for persona in user_file.assistants:
+        if persona.deleted:
+            continue
+        for group in persona.groups:
+            groups.add(group.name)
+    return groups
+
+
+def get_access_for_user_files_impl(
+    user_file_ids: list[str],
+    db_session: Session,
+) -> dict[str, DocumentAccess]:
+    """EE version: extends the MIT user file ACL with user group names
+    from personas shared via user groups.
+
+    Uses a single DB query (via fetch_user_files_with_access_relationships)
+    that eagerly loads both the MIT-needed and EE-needed relationships.
+
+    NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
+    DO NOT REMOVE."""
+    user_files = fetch_user_files_with_access_relationships(
+        user_file_ids, db_session, eager_load_groups=True
+    )
+    return build_access_for_user_files_impl(user_files)
+
+
+def build_access_for_user_files_impl(
+    user_files: list[UserFile],
+) -> dict[str, DocumentAccess]:
+    """EE version: works on pre-loaded UserFile objects.
+    Expects Persona.groups to be eagerly loaded.
+
+    NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
+    DO NOT REMOVE."""
+    result: dict[str, DocumentAccess] = {}
+    for user_file in user_files:
+        if user_file.user is None:
+            result[str(user_file.id)] = DocumentAccess.build(
+                user_emails=[],
+                user_groups=[],
+                is_public=True,
+                external_user_emails=[],
+                external_user_group_ids=[],
+            )
+            continue
+
+        emails, is_public = collect_user_file_access(user_file)
+        group_names = _collect_user_file_group_names(user_file)
+        result[str(user_file.id)] = DocumentAccess.build(
+            user_emails=list(emails),
+            user_groups=list(group_names),
+            is_public=is_public,
+            external_user_emails=[],
+            external_user_group_ids=[],
+        )
+    return result
+
+
 def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
     """Returns a list of ACL entries that the user has access to. This is meant to be
     used downstream to filter out documents that the user does not have access to. The
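The "DO NOT REMOVE" notes above exist because `fetch_versioned_implementation` resolves these functions by name at runtime, so static analysis sees them as unused. As a hedged sketch of how that kind of dispatch typically works (the real helper lives in `onyx.utils.variable_functionality`; this simplified version is an assumption, not its actual code):

```python
import importlib
from typing import Any


def fetch_versioned_implementation_sketch(module: str, attribute: str) -> Any:
    """Prefer the enterprise ("ee.") variant of a function when present,
    falling back to the base (MIT) implementation otherwise."""
    try:
        return getattr(importlib.import_module("ee." + module), attribute)
    except (ImportError, AttributeError):
        return getattr(importlib.import_module(module), attribute)


# Hypothetical usage mirroring the NOTE in the diff:
# impl = fetch_versioned_implementation_sketch(
#     "onyx.access.access", "get_access_for_user_files_impl"
# )
```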
@@ -1,3 +1,4 @@
+import os
 from datetime import datetime

 import jwt

@@ -20,7 +21,13 @@ logger = setup_logger()


 def verify_auth_setting() -> None:
-    # All the Auth flows are valid for EE version
+    # All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
+    raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
+    if raw_auth_type == "disabled":
+        logger.warning(
+            "AUTH_TYPE='disabled' is no longer supported. "
+            "Using 'basic' instead. Please update your configuration."
+        )
     logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -1,15 +0,0 @@ (deleted file)
-from onyx.background.celery.apps import app_base
-from onyx.background.celery.apps.background import celery_app
-
-
-celery_app.autodiscover_tasks(
-    app_base.filter_task_modules(
-        [
-            "ee.onyx.background.celery.tasks.doc_permission_syncing",
-            "ee.onyx.background.celery.tasks.external_group_syncing",
-            "ee.onyx.background.celery.tasks.cleanup",
-            "ee.onyx.background.celery.tasks.tenant_provisioning",
-            "ee.onyx.background.celery.tasks.query_history",
-        ]
-    )
-)
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode


 def _build_hierarchy_access_filter(
-    user_email: str | None,
+    user_email: str,
     external_group_ids: list[str],
 ) -> ColumnElement[bool]:
     """Build SQLAlchemy filter for hierarchy node access.

@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
 def _get_accessible_hierarchy_nodes_for_source(
     db_session: Session,
     source: DocumentSource,
-    user_email: str | None,
+    user_email: str,
     external_group_ids: list[str],
 ) -> list[HierarchyNode]:
     """
@@ -11,11 +11,10 @@ from ee.onyx.server.license.models import LicenseMetadata
 from ee.onyx.server.license.models import LicensePayload
 from ee.onyx.server.license.models import LicenseSource
 from onyx.auth.schemas import UserRole
+from onyx.cache.factory import get_cache_backend
 from onyx.configs.constants import ANONYMOUS_USER_EMAIL
 from onyx.db.models import License
 from onyx.db.models import User
-from onyx.redis.redis_pool import get_redis_client
-from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import get_current_tenant_id

@@ -142,7 +141,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:

 def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
     """
-    Get license metadata from Redis cache.
+    Get license metadata from cache.

     Args:
         tenant_id: Tenant ID (for multi-tenant deployments)

@@ -150,38 +149,34 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
     Returns:
         LicenseMetadata if cached, None otherwise
     """
-    tenant = tenant_id or get_current_tenant_id()
-    redis_client = get_redis_replica_client(tenant_id=tenant)
-
-    cached = redis_client.get(LICENSE_METADATA_KEY)
-    if cached:
-        try:
-            cached_str: str
-            if isinstance(cached, bytes):
-                cached_str = cached.decode("utf-8")
-            else:
-                cached_str = str(cached)
-            return LicenseMetadata.model_validate_json(cached_str)
-        except Exception as e:
-            logger.warning(f"Failed to parse cached license metadata: {e}")
-            return None
-    return None
+    cache = get_cache_backend(tenant_id=tenant_id)
+    cached = cache.get(LICENSE_METADATA_KEY)
+    if not cached:
+        return None
+
+    try:
+        cached_str = (
+            cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
+        )
+        return LicenseMetadata.model_validate_json(cached_str)
+    except Exception as e:
+        logger.warning(f"Failed to parse cached license metadata: {e}")
+        return None


 def invalidate_license_cache(tenant_id: str | None = None) -> None:
     """
     Invalidate the license metadata cache (not the license itself).

-    This deletes the cached LicenseMetadata from Redis. The actual license
-    in the database is not affected. Redis delete is idempotent - if the
-    key doesn't exist, this is a no-op.
+    Deletes the cached LicenseMetadata. The actual license in the database
+    is not affected. Delete is idempotent — if the key doesn't exist, this
+    is a no-op.

     Args:
         tenant_id: Tenant ID (for multi-tenant deployments)
     """
-    tenant = tenant_id or get_current_tenant_id()
-    redis_client = get_redis_client(tenant_id=tenant)
-    redis_client.delete(LICENSE_METADATA_KEY)
+    cache = get_cache_backend(tenant_id=tenant_id)
+    cache.delete(LICENSE_METADATA_KEY)
     logger.info("License cache invalidated")


@@ -192,7 +187,7 @@ def update_license_cache(
     tenant_id: str | None = None,
 ) -> LicenseMetadata:
     """
-    Update the Redis cache with license metadata.
+    Update the cache with license metadata.

     We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
     1. Frontend needs status to show appropriate UI/banners

@@ -211,7 +206,7 @@ def update_license_cache(
     from ee.onyx.utils.license import get_license_status

     tenant = tenant_id or get_current_tenant_id()
-    redis_client = get_redis_client(tenant_id=tenant)
+    cache = get_cache_backend(tenant_id=tenant_id)

     used_seats = get_used_seats(tenant)
     status = get_license_status(payload, grace_period_end)

@@ -230,7 +225,7 @@ def update_license_cache(
         stripe_subscription_id=payload.stripe_subscription_id,
     )

-    redis_client.set(
+    cache.set(
         LICENSE_METADATA_KEY,
         metadata.model_dump_json(),
         ex=LICENSE_CACHE_TTL_SECONDS,
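The swap from raw Redis clients to `get_cache_backend` implies a small key-value interface. A hedged sketch of what that backend protocol presumably looks like, inferred only from the calls in the diff (`get`, `delete`, and `set(..., ex=ttl)`); the actual definition lives in `onyx.cache.interface` and may differ:

```python
from typing import Protocol


class CacheBackend(Protocol):
    """Inferred from usage in the diff above; not the real interface."""

    def get(self, key: str) -> bytes | str | None:
        """Return the cached value, or None on a miss."""
        ...

    def set(self, key: str, value: str, ex: int | None = None) -> None:
        """Store value, optionally expiring after `ex` seconds (Redis-style)."""
        ...

    def delete(self, key: str) -> None:
        """Remove the key; a no-op if it does not exist."""
        ...
```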
@@ -7,6 +7,7 @@ from onyx.db.models import Persona
 from onyx.db.models import Persona__User
 from onyx.db.models import Persona__UserGroup
 from onyx.db.notification import create_notification
+from onyx.db.persona import mark_persona_user_files_for_sync
 from onyx.server.features.persona.models import PersonaSharedNotificationData


@@ -26,7 +27,9 @@ def update_persona_access(

     NOTE: Callers are responsible for committing."""

+    needs_sync = False
     if is_public is not None:
+        needs_sync = True
         persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
         if persona:
             persona.is_public = is_public

@@ -35,6 +38,7 @@ def update_persona_access(
     # and a non-empty list means "replace with these shares".

     if user_ids is not None:
+        needs_sync = True
         db_session.query(Persona__User).filter(
             Persona__User.persona_id == persona_id
         ).delete(synchronize_session="fetch")

@@ -54,6 +58,7 @@ def update_persona_access(
         )

     if group_ids is not None:
+        needs_sync = True
         db_session.query(Persona__UserGroup).filter(
             Persona__UserGroup.persona_id == persona_id
         ).delete(synchronize_session="fetch")

@@ -63,3 +68,7 @@ def update_persona_access(
             db_session.add(
                 Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
             )
+
+    # When sharing changes, user file ACLs need to be updated in the vector DB
+    if needs_sync:
+        mark_persona_user_files_for_sync(persona_id, db_session)
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
 from ee.onyx.server.user_group.models import SetCuratorRequest
 from ee.onyx.server.user_group.models import UserGroupCreate
 from ee.onyx.server.user_group.models import UserGroupUpdate
+from onyx.configs.app_configs import DISABLE_VECTOR_DB
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.enums import AccessType
 from onyx.db.enums import ConnectorCredentialPairStatus

@@ -471,7 +472,9 @@ def _add_user_group__cc_pair_relationships__no_commit(

 def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
     db_user_group = UserGroup(
-        name=user_group.name, time_last_modified_by_user=func.now()
+        name=user_group.name,
+        time_last_modified_by_user=func.now(),
+        is_up_to_date=DISABLE_VECTOR_DB,
     )
     db_session.add(db_user_group)
     db_session.flush()  # give the group an ID

@@ -774,8 +777,7 @@ def update_user_group(
         cc_pair_ids=user_group_update.cc_pair_ids,
     )

-    # only needs to sync with Vespa if the cc_pairs have been updated
-    if cc_pairs_updated:
+    if cc_pairs_updated and not DISABLE_VECTOR_DB:
         db_user_group.is_up_to_date = False

     removed_users = db_session.scalars(
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
     company_domain: str,
     retriever_drive_service: GoogleDriveService | None,
     admin_drive_service: GoogleDriveService,
+    fallback_user_email: str,
     add_prefix: bool = False,
 ) -> ExternalAccess:
     """

@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
     set add_prefix to True so group IDs are prefixed with the source type.
     When invoked from doc_sync (permission sync), use the default (False)
     since upsert_document_external_perms handles prefixing.
+    fallback_user_email: When we cannot retrieve any permission info for a file
+        (e.g. externally-owned files where the API returns no permissions
+        and permissions.list returns 403), fall back to granting access
+        to this user. This is typically the impersonated org user whose
+        drive contained the file.
     """
     doc_id = file.get("id")
     if not doc_id:

@@ -117,6 +123,26 @@ def get_external_access_for_raw_gdrive_file(
         [permissions_list, backup_permissions_list]
     )

+    # For externally-owned files, the Drive API may return no permissions
+    # and permissions.list may return 403. In this case, fall back to
+    # granting access to the user who found the file in their drive.
+    # Note, even if other users also have access to this file,
+    # they will not be granted access in Onyx.
+    # We check permissions_list (the final result after all fetch attempts)
+    # rather than the raw fields, because permission_ids may be present
+    # but the actual fetch can still return empty due to a 403.
+    if not permissions_list:
+        logger.info(
+            f"No permission info available for file {doc_id} "
+            f"(likely owned by a user outside of your organization). "
+            f"Falling back to granting access to retriever user: {fallback_user_email}"
+        )
+        return ExternalAccess(
+            external_user_emails={fallback_user_email},
+            external_user_group_ids=set(),
+            is_public=False,
+        )
+
     folder_ids_to_inherit_permissions_from: set[str] = set()
     user_emails: set[str] = set()
     group_emails: set[str] = set()
@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from httpx_oauth.clients.google import GoogleOAuth2

-from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
 from ee.onyx.server.analytics.api import router as analytics_router
 from ee.onyx.server.auth_check import check_ee_router_auth
 from ee.onyx.server.billing.api import router as billing_router

@@ -153,12 +152,9 @@ def get_application() -> FastAPI:
     # License management
     include_router_with_global_prefix_prepended(application, license_router)

-    # Unified billing API - available when license system is enabled
-    # Works for both self-hosted and cloud deployments
-    # TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
-    # primary billing API and /tenants/* billing endpoints can be removed
-    if LICENSE_ENFORCEMENT_ENABLED:
-        include_router_with_global_prefix_prepended(application, billing_router)
+    # Unified billing API - always registered in EE.
+    # Each endpoint is protected by the `current_admin_user` dependency (admin auth).
+    include_router_with_global_prefix_prepended(application, billing_router)

     if MULTI_TENANT:
         # Tenant management
@@ -26,7 +26,6 @@ import asyncio
 import httpx
 from fastapi import APIRouter
 from fastapi import Depends
-from fastapi import HTTPException
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

@@ -42,7 +41,6 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
 from ee.onyx.server.billing.models import SeatUpdateResponse
 from ee.onyx.server.billing.models import StripePublishableKeyResponse
 from ee.onyx.server.billing.models import SubscriptionStatusResponse
-from ee.onyx.server.billing.service import BillingServiceError
 from ee.onyx.server.billing.service import (
     create_checkout_session as create_checkout_service,
 )

@@ -58,6 +56,8 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
 from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.db.engine.sql_engine import get_session
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.redis.redis_pool import get_shared_redis_client
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -169,26 +169,23 @@ async def create_checkout_session(
     if seats is not None:
         used_seats = get_used_seats(tenant_id)
         if seats < used_seats:
-            raise HTTPException(
-                status_code=400,
-                detail=f"Cannot subscribe with fewer seats than current usage. "
+            raise OnyxError(
+                OnyxErrorCode.VALIDATION_ERROR,
+                f"Cannot subscribe with fewer seats than current usage. "
                 f"You have {used_seats} active users/integrations but requested {seats} seats.",
             )

     # Build redirect URL for after checkout completion
     redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"

-    try:
-        return await create_checkout_service(
-            billing_period=billing_period,
-            seats=seats,
-            email=email,
-            license_data=license_data,
-            redirect_url=redirect_url,
-            tenant_id=tenant_id,
-        )
-    except BillingServiceError as e:
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+    return await create_checkout_service(
+        billing_period=billing_period,
+        seats=seats,
+        email=email,
+        license_data=license_data,
+        redirect_url=redirect_url,
+        tenant_id=tenant_id,
+    )


 @router.post("/create-customer-portal-session")

@@ -206,18 +203,15 @@ async def create_customer_portal_session(

     # Self-hosted requires license
     if not MULTI_TENANT and not license_data:
-        raise HTTPException(status_code=400, detail="No license found")
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")

     return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"

-    try:
-        return await create_portal_service(
-            license_data=license_data,
-            return_url=return_url,
-            tenant_id=tenant_id,
-        )
-    except BillingServiceError as e:
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+    return await create_portal_service(
+        license_data=license_data,
+        return_url=return_url,
+        tenant_id=tenant_id,
+    )


 @router.get("/billing-information")

@@ -240,9 +234,9 @@ async def get_billing_information(

     # Check circuit breaker (self-hosted only)
     if _is_billing_circuit_open():
-        raise HTTPException(
-            status_code=503,
-            detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
+        raise OnyxError(
+            OnyxErrorCode.SERVICE_UNAVAILABLE,
+            "Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
         )

     try:

@@ -250,11 +244,15 @@ async def get_billing_information(
             license_data=license_data,
             tenant_id=tenant_id,
         )
-    except BillingServiceError as e:
+    except OnyxError as e:
         # Open circuit breaker on connection failures (self-hosted only)
-        if e.status_code in (502, 503, 504):
+        if e.status_code in (
+            OnyxErrorCode.BAD_GATEWAY.status_code,
+            OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
+            OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
+        ):
             _open_billing_circuit()
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+        raise


 @router.post("/seats/update")

@@ -274,31 +272,25 @@ async def update_seats(

     # Self-hosted requires license
     if not MULTI_TENANT and not license_data:
-        raise HTTPException(status_code=400, detail="No license found")
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")

     # Validate that new seat count is not less than current used seats
     used_seats = get_used_seats(tenant_id)
     if request.new_seat_count < used_seats:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Cannot reduce seats below current usage. "
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            f"Cannot reduce seats below current usage. "
             f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
         )

-    try:
-        result = await update_seat_service(
-            new_seat_count=request.new_seat_count,
-            license_data=license_data,
-            tenant_id=tenant_id,
-        )
-
-        # Note: Don't store license here - the control plane may still be processing
-        # the subscription update. The frontend should call /license/claim after a
-        # short delay to get the freshly generated license.
-
-        return result
-    except BillingServiceError as e:
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+    # Note: Don't store license here - the control plane may still be processing
+    # the subscription update. The frontend should call /license/claim after a
+    # short delay to get the freshly generated license.
+    return await update_seat_service(
+        new_seat_count=request.new_seat_count,
+        license_data=license_data,
+        tenant_id=tenant_id,
+    )


 @router.get("/stripe-publishable-key")

@@ -329,18 +321,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
     if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
         key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
         if not key.startswith("pk_"):
-            raise HTTPException(
-                status_code=500,
-                detail="Invalid Stripe publishable key format",
+            raise OnyxError(
+                OnyxErrorCode.INTERNAL_ERROR,
+                "Invalid Stripe publishable key format",
             )
         _stripe_publishable_key_cache = key
         return StripePublishableKeyResponse(publishable_key=key)

     # Fall back to S3 bucket
     if not STRIPE_PUBLISHABLE_KEY_URL:
-        raise HTTPException(
-            status_code=500,
-            detail="Stripe publishable key is not configured",
+        raise OnyxError(
+            OnyxErrorCode.INTERNAL_ERROR,
+            "Stripe publishable key is not configured",
        )

     try:

@@ -351,17 +343,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:

         # Validate key format
         if not key.startswith("pk_"):
-            raise HTTPException(
-                status_code=500,
-                detail="Invalid Stripe publishable key format",
+            raise OnyxError(
+                OnyxErrorCode.INTERNAL_ERROR,
+                "Invalid Stripe publishable key format",
             )

         _stripe_publishable_key_cache = key
         return StripePublishableKeyResponse(publishable_key=key)
     except httpx.HTTPError:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to fetch Stripe publishable key",
+        raise OnyxError(
+            OnyxErrorCode.INTERNAL_ERROR,
+            "Failed to fetch Stripe publishable key",
         )
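`_is_billing_circuit_open()` and `_open_billing_circuit()` are referenced above but not shown in this diff. A hedged sketch of the kind of Redis-backed circuit breaker they suggest; the key name and cooldown window are assumptions, not the actual Onyx implementation:

```python
import redis

# Hypothetical key and cooldown window; the real values are not in the diff.
_CIRCUIT_KEY = "billing:circuit_open"
_CIRCUIT_TTL_SECONDS = 300

r = redis.Redis()


def _open_billing_circuit() -> None:
    # Setting a TTL'd flag "opens" the circuit; it closes automatically
    # once the key expires, allowing a fresh retry against Stripe.
    r.set(_CIRCUIT_KEY, "1", ex=_CIRCUIT_TTL_SECONDS)


def _is_billing_circuit_open() -> bool:
    return r.exists(_CIRCUIT_KEY) == 1
```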
@@ -22,6 +22,8 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
 from ee.onyx.server.billing.models import SubscriptionStatusResponse
 from ee.onyx.server.tenants.access import generate_data_plane_token
 from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -31,15 +33,6 @@ logger = setup_logger()
 _REQUEST_TIMEOUT = 30.0


-class BillingServiceError(Exception):
-    """Exception raised for billing service errors."""
-
-    def __init__(self, message: str, status_code: int = 500):
-        self.message = message
-        self.status_code = status_code
-        super().__init__(self.message)
-
-
 def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
     """Build headers for proxy requests (self-hosted).

@@ -101,7 +94,7 @@ async def _make_billing_request(
         Response JSON as dict

     Raises:
-        BillingServiceError: If request fails
+        OnyxError: If request fails
     """

     base_url = _get_base_url()

@@ -128,11 +121,17 @@ async def _make_billing_request(
         except Exception:
             pass
         logger.error(f"{error_message}: {e.response.status_code} - {detail}")
-        raise BillingServiceError(detail, e.response.status_code)
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY,
+            detail,
+            status_code_override=e.response.status_code,
+        )

     except httpx.RequestError:
         logger.exception("Failed to connect to billing service")
-        raise BillingServiceError("Failed to connect to billing service", 502)
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
+        )


 async def create_checkout_session(
@@ -223,6 +223,15 @@ def get_active_scim_token(
     token = dal.get_active_token()
     if not token:
         raise HTTPException(status_code=404, detail="No active SCIM token")
+
+    # Derive the IdP domain from the first synced user as a heuristic.
+    idp_domain: str | None = None
+    mappings, _total = dal.list_user_mappings(start_index=1, count=1)
+    if mappings:
+        user = dal.get_user(mappings[0].user_id)
+        if user and "@" in user.email:
+            idp_domain = user.email.rsplit("@", 1)[1]
+
     return ScimTokenResponse(
         id=token.id,
         name=token.name,

@@ -230,6 +239,7 @@ def get_active_scim_token(
         is_active=token.is_active,
         created_at=token.created_at,
         last_used_at=token.last_used_at,
+        idp_domain=idp_domain,
     )
@@ -14,7 +14,6 @@ import requests
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import File
-from fastapi import HTTPException
 from fastapi import UploadFile
 from sqlalchemy.orm import Session

@@ -35,6 +34,8 @@ from ee.onyx.server.license.models import SeatUsageResponse
 from ee.onyx.utils.license import verify_license_signature
 from onyx.auth.users import User
 from onyx.db.engine.sql_engine import get_session
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -127,9 +128,9 @@ async def claim_license(
     2. Without session_id: Re-claim using existing license for auth
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License claiming is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License claiming is only available for self-hosted deployments",
         )

     try:

@@ -146,15 +147,16 @@ async def claim_license(
         # Re-claim using existing license for auth
         metadata = get_license_metadata(db_session)
         if not metadata or not metadata.tenant_id:
-            raise HTTPException(
-                status_code=400,
-                detail="No license found. Provide session_id after checkout.",
+            raise OnyxError(
+                OnyxErrorCode.VALIDATION_ERROR,
+                "No license found. Provide session_id after checkout.",
             )

         license_row = get_license(db_session)
         if not license_row or not license_row.license_data:
-            raise HTTPException(
-                status_code=400, detail="No license found in database"
+            raise OnyxError(
+                OnyxErrorCode.VALIDATION_ERROR,
+                "No license found in database",
             )

         url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"

@@ -173,7 +175,7 @@ async def claim_license(
         license_data = data.get("license")

         if not license_data:
-            raise HTTPException(status_code=404, detail="No license in response")
+            raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")

         # Verify signature before persisting
         payload = verify_license_signature(license_data)

@@ -199,12 +201,14 @@ async def claim_license(
                 detail = error_data.get("detail", detail)
             except Exception:
                 pass
-        raise HTTPException(status_code=status_code, detail=detail)
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
+        )
     except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
     except requests.RequestException:
-        raise HTTPException(
-            status_code=502, detail="Failed to connect to license server"
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
         )

@@ -221,9 +225,9 @@ async def upload_license(
     The license file must be cryptographically signed by Onyx.
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License upload is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License upload is only available for self-hosted deployments",
         )

     try:

@@ -234,14 +238,14 @@ async def upload_license(
         # Remove any stray whitespace/newlines from user input
         license_data = license_data.strip()
     except UnicodeDecodeError:
-        raise HTTPException(status_code=400, detail="Invalid license file format")
+        raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")

     # Verify cryptographic signature - this is the only validation needed
     # The license's tenant_id identifies the customer in control plane, not locally
     try:
         payload = verify_license_signature(license_data)
     except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))

     # Persist to DB and update cache
     upsert_license(db_session, license_data)

@@ -297,9 +301,9 @@ async def delete_license(
     Admin only - removes license from database and invalidates cache.
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License deletion is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License deletion is only available for self-hosted deployments",
         )

     try:
@@ -46,7 +46,6 @@ from fastapi import FastAPI
 from fastapi import Request
 from fastapi import Response
 from fastapi.responses import JSONResponse
-from redis.exceptions import RedisError
 from sqlalchemy.exc import SQLAlchemyError

 from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED

@@ -56,6 +55,7 @@ from ee.onyx.configs.license_enforcement_config import (
 )
 from ee.onyx.db.license import get_cached_license_metadata
 from ee.onyx.db.license import refresh_license_cache
+from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.server.settings.models import ApplicationStatus
 from shared_configs.contextvars import get_current_tenant_id

@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
                 "[license_enforcement] No license, allowing community features"
             )
             is_gated = False
-    except RedisError as e:
+    except CACHE_TRANSIENT_ERRORS as e:
         logger.warning(f"Failed to check license metadata: {e}")
-        # Fail open - don't block users due to Redis connectivity issues
+        # Fail open - don't block users due to cache connectivity issues
         is_gated = False

     if is_gated:
@@ -365,6 +365,7 @@ class ScimTokenResponse(BaseModel):
     is_active: bool
     created_at: datetime
     last_used_at: datetime | None = None
+    idp_domain: str | None = None


 class ScimTokenCreatedResponse(ScimTokenResponse):
@@ -26,6 +26,7 @@ from onyx.db.models import Tool
 from onyx.db.persona import upsert_persona
 from onyx.server.features.persona.models import PersonaUpsertRequest
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
+from onyx.server.manage.llm.models import LLMProviderView
 from onyx.server.settings.models import Settings
 from onyx.server.settings.store import store_settings as store_base_settings
 from onyx.utils.logger import setup_logger

@@ -125,10 +126,16 @@ def _seed_llms(
         existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
         if existing:
             request.id = existing.id
-    seeded_providers = [
-        upsert_llm_provider(llm_upsert_request, db_session)
-        for llm_upsert_request in llm_upsert_requests
-    ]
+    seeded_providers: list[LLMProviderView] = []
+    for llm_upsert_request in llm_upsert_requests:
+        try:
+            seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
+        except ValueError as e:
+            logger.warning(
+                "Failed to upsert LLM provider '%s' during seeding: %s",
+                llm_upsert_request.name,
+                e,
+            )

     default_provider = next(
         (p for p in seeded_providers if p.model_configurations), None
@@ -6,6 +6,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
 from ee.onyx.db.license import get_cached_license_metadata
 from ee.onyx.db.license import refresh_license_cache
+from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
 from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.server.settings.models import ApplicationStatus

@@ -125,7 +126,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
         # syncing) means indexed data may need protection.
         settings.application_status = _BLOCKING_STATUS
         settings.ee_features_enabled = False
-    except RedisError as e:
+    except CACHE_TRANSIENT_ERRORS as e:
         logger.warning(f"Failed to check license metadata for settings: {e}")
         # Fail closed - disable EE features if we can't verify license
         settings.ee_features_enabled = False
@@ -21,7 +21,6 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
@@ -43,6 +42,8 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -116,9 +117,14 @@ async def create_customer_portal_session(
    try:
        portal_url = fetch_customer_portal_session(tenant_id, return_url)
        return {"stripe_customer_portal_url": portal_url}
    except Exception as e:
    except OnyxError:
        raise
    except Exception:
        logger.exception("Failed to create customer portal session")
        raise HTTPException(status_code=500, detail=str(e))
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Failed to create customer portal session",
        )


@router.post("/create-checkout-session")
@@ -134,9 +140,14 @@ async def create_checkout_session(
    try:
        checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
        return {"stripe_checkout_url": checkout_url}
    except Exception as e:
    except OnyxError:
        raise
    except Exception:
        logger.exception("Failed to create checkout session")
        raise HTTPException(status_code=500, detail=str(e))
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Failed to create checkout session",
        )


@router.post("/create-subscription-session")
@@ -147,15 +158,20 @@ async def create_subscription_session(
    try:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Tenant ID not found")
            raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")

        billing_period = request.billing_period if request else "monthly"
        session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
        return SubscriptionSessionResponse(sessionId=session_id)

    except Exception as e:
    except OnyxError:
        raise
    except Exception:
        logger.exception("Failed to create subscription session")
        raise HTTPException(status_code=500, detail=str(e))
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Failed to create subscription session",
        )


@router.get("/stripe-publishable-key")
@@ -186,18 +202,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
    if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
        key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
        if not key.startswith("pk_"):
            raise HTTPException(
                status_code=500,
                detail="Invalid Stripe publishable key format",
            raise OnyxError(
                OnyxErrorCode.INTERNAL_ERROR,
                "Invalid Stripe publishable key format",
            )
        _stripe_publishable_key_cache = key
        return StripePublishableKeyResponse(publishable_key=key)

    # Fall back to S3 bucket
    if not STRIPE_PUBLISHABLE_KEY_URL:
        raise HTTPException(
            status_code=500,
            detail="Stripe publishable key is not configured",
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Stripe publishable key is not configured",
        )

    try:
@@ -208,15 +224,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:

        # Validate key format
        if not key.startswith("pk_"):
            raise HTTPException(
                status_code=500,
                detail="Invalid Stripe publishable key format",
            raise OnyxError(
                OnyxErrorCode.INTERNAL_ERROR,
                "Invalid Stripe publishable key format",
            )

        _stripe_publishable_key_cache = key
        return StripePublishableKeyResponse(publishable_key=key)
    except httpx.HTTPError:
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch Stripe publishable key",
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Failed to fetch Stripe publishable key",
        )
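The hunks above swap ad-hoc HTTPException raises (which leaked str(e) internals in 500 responses) for typed OnyxError domain errors. A minimal sketch of how such errors can be translated back into HTTP responses in one place, using stand-in OnyxError/OnyxErrorCode classes; the real ones live in onyx.error_handling and may differ:

from enum import Enum

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class OnyxErrorCode(str, Enum):  # stand-in; assumed to mirror onyx.error_handling.error_codes
    VALIDATION_ERROR = "validation_error"
    INTERNAL_ERROR = "internal_error"


class OnyxError(Exception):  # stand-in; assumed to mirror onyx.error_handling.exceptions
    def __init__(self, code: OnyxErrorCode, message: str) -> None:
        super().__init__(message)
        self.code = code
        self.message = message


app = FastAPI()


@app.exception_handler(OnyxError)
async def onyx_error_handler(request: Request, exc: OnyxError) -> JSONResponse:
    # One translation point: endpoints raise domain errors and this handler
    # picks the HTTP status, so exception internals never leak to clients.
    status = 400 if exc.code == OnyxErrorCode.VALIDATION_ERROR else 500
    return JSONResponse(
        status_code=status,
        content={"code": exc.code.value, "detail": exc.message},
    )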

@@ -5,6 +5,8 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -20,6 +22,7 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -153,3 +156,8 @@ def delete_user_group(
        prepare_user_group_for_deletion(db_session, user_group_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    if DISABLE_VECTOR_DB:
        user_group = fetch_user_group(db_session, user_group_id)
        if user_group:
            db_delete_user_group(db_session, user_group)

@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()


@lru_cache(maxsize=1)
@lru_cache(maxsize=2)
def _get_trimmed_key(key: str) -> bytes:
    encoded_key = key.encode()
    key_length = len(encoded_key)
    if key_length < 16:
        raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
    elif key_length > 32:
        key = key[:32]
    elif key_length not in (16, 24, 32):
        valid_lengths = [16, 24, 32]
        key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]

    return encoded_key
    # Trim to the largest valid AES key size that fits
    valid_lengths = [32, 24, 16]
    for size in valid_lengths:
        if key_length >= size:
            return encoded_key[:size]

    raise AssertionError("unreachable")
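The rewritten _get_trimmed_key always derives a valid AES key length; the old branches computed a trim but then returned the untrimmed encoded_key anyway, and the nearest-length rule could even pick a length larger than the input. The lru_cache bump to maxsize=2 presumably leaves room for both the current key and an explicitly passed rotation key. A tiny standalone restatement of the new rule, with the boundary cases spelled out:

def trim_key(raw: str) -> bytes:
    """Standalone restatement of the trimming rule above, for illustration."""
    encoded = raw.encode()
    if len(encoded) < 16:
        raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
    for size in (32, 24, 16):  # largest valid AES key size that fits
        if len(encoded) >= size:
            return encoded[:size]
    raise AssertionError("unreachable")


assert len(trim_key("x" * 20)) == 16  # 16-23 bytes -> AES-128
assert len(trim_key("x" * 31)) == 24  # 24-31 bytes -> AES-192
assert len(trim_key("x" * 40)) == 32  # >= 32 bytes -> truncated to AES-256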


def _encrypt_string(input_str: str) -> bytes:
    if not ENCRYPTION_KEY_SECRET:
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
    effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
    if not effective_key:
        return input_str.encode()

    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    trimmed = _get_trimmed_key(effective_key)
    iv = urandom(16)
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded_data = padder.update(input_str.encode()) + padder.finalize()

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    encrypted_data = encryptor.update(padded_data) + encryptor.finalize()

    return iv + encrypted_data


def _decrypt_bytes(input_bytes: bytes) -> str:
    if not ENCRYPTION_KEY_SECRET:
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
    effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
    if not effective_key:
        return input_bytes.decode()

    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    iv = input_bytes[:16]
    encrypted_data = input_bytes[16:]
    trimmed = _get_trimmed_key(effective_key)
    try:
        iv = input_bytes[:16]
        encrypted_data = input_bytes[16:]

        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()
        decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
        cipher = Cipher(
            algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
        )
        decryptor = cipher.decryptor()
        decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()

        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
        decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
        decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()

        return decrypted_data.decode()
        return decrypted_data.decode()
    except (ValueError, UnicodeDecodeError):
        if key is not None:
            # Explicit key was provided — don't fall back silently
            raise
        # Read path: attempt raw UTF-8 decode as a fallback for legacy data.
        # Does NOT handle data encrypted with a different key — that
        # ciphertext is not valid UTF-8 and will raise below.
        logger.warning(
            "AES decryption failed — falling back to raw decode. "
            "Run the re-encrypt secrets script to rotate to the current key."
        )
        try:
            return input_bytes.decode()
        except UnicodeDecodeError:
            raise ValueError(
                "Data is not valid UTF-8 — likely encrypted with a different key. "
                "Run the re-encrypt secrets script to rotate to the current key."
            ) from None
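For context, a self-contained round trip of the AES-CBC + PKCS7 scheme used above, with the IV prepended to the ciphertext exactly as in _encrypt_string. This is a sketch of the technique, not the Onyx implementation:

from os import urandom

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes


def encrypt(plaintext: str, key: bytes) -> bytes:
    iv = urandom(16)  # fresh IV per message, stored as the first 16 bytes
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded = padder.update(plaintext.encode()) + padder.finalize()
    enc = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).encryptor()
    return iv + enc.update(padded) + enc.finalize()


def decrypt(blob: bytes, key: bytes) -> str:
    iv, ciphertext = blob[:16], blob[16:]
    dec = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).decryptor()
    padded = dec.update(ciphertext) + dec.finalize()
    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
    return (unpadder.update(padded) + unpadder.finalize()).decode()


key = urandom(32)
assert decrypt(encrypt("secret", key), key) == "secret"
# Wrong-key decryption typically surfaces as a padding ValueError, or as
# garbage bytes that fail UTF-8 decoding: the failure modes the fallback
# above catches.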


def encrypt_string_to_bytes(input_str: str) -> bytes:
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
    versioned_encryption_fn = fetch_versioned_implementation(
        "onyx.utils.encryption", "_encrypt_string"
    )
    return versioned_encryption_fn(input_str)
    return versioned_encryption_fn(input_str, key=key)


def decrypt_bytes_to_string(input_bytes: bytes) -> str:
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
    versioned_decryption_fn = fetch_versioned_implementation(
        "onyx.utils.encryption", "_decrypt_bytes"
    )
    return versioned_decryption_fn(input_bytes)
    return versioned_decryption_fn(input_bytes, key=key)


def test_encryption() -> None:

@@ -1,7 +1,6 @@
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

from onyx.access.models import DocumentAccess
@@ -12,6 +11,7 @@ from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation

@@ -132,19 +132,61 @@ def get_access_for_user_files(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    user_files = (
        db_session.query(UserFile)
        .options(joinedload(UserFile.user))  # Eager load the user relationship
        .filter(UserFile.id.in_(user_file_ids))
        .all()
    versioned_fn = fetch_versioned_implementation(
        "onyx.access.access", "get_access_for_user_files_impl"
    )
    return {
        str(user_file.id): DocumentAccess.build(
            user_emails=[user_file.user.email] if user_file.user else [],
    return versioned_fn(user_file_ids, db_session)


def get_access_for_user_files_impl(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
    return build_access_for_user_files_impl(user_files)


def build_access_for_user_files(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    """Compute access from pre-loaded UserFile objects (with relationships).
    Callers must ensure UserFile.user, Persona.users, and Persona.user are
    eagerly loaded (and Persona.groups for the EE path)."""
    versioned_fn = fetch_versioned_implementation(
        "onyx.access.access", "build_access_for_user_files_impl"
    )
    return versioned_fn(user_files)


def build_access_for_user_files_impl(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    result: dict[str, DocumentAccess] = {}
    for user_file in user_files:
        emails, is_public = collect_user_file_access(user_file)
        result[str(user_file.id)] = DocumentAccess.build(
            user_emails=list(emails),
            user_groups=[],
            is_public=True if user_file.user is None else False,
            is_public=is_public,
            external_user_emails=[],
            external_user_group_ids=[],
        )
        for user_file in user_files
    }
    return result


def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
    """Collect all user emails that should have access to this user file.
    Includes the owner plus any users who have access via shared personas.
    Returns (emails, is_public)."""
    emails: set[str] = {user_file.user.email}
    is_public = False
    for persona in user_file.assistants:
        if persona.deleted:
            continue
        if persona.is_public:
            is_public = True
        if persona.user_id is not None and persona.user:
            emails.add(persona.user.email)
        for shared_user in persona.users:
            emails.add(shared_user.email)
    return emails, is_public
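The build_access_for_user_files docstring above puts the eager-loading burden on callers. A hypothetical loader sketch that satisfies it with selectinload; the model and relationship names are taken from the diff, while the helper name and exact option set are assumptions:

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from onyx.db.models import UserFile


def load_user_files_for_access(
    user_file_ids: list[str], db_session: Session, include_groups: bool = False
) -> list[UserFile]:
    """Hypothetical loader: eagerly pull every relationship that
    collect_user_file_access() touches, so no lazy loads fire per file."""
    from onyx.db.models import Persona  # assumed model location

    options = [
        selectinload(UserFile.user),
        selectinload(UserFile.assistants).selectinload(Persona.user),
        selectinload(UserFile.assistants).selectinload(Persona.users),
    ]
    if include_groups:  # EE path, per the docstring above
        options.append(selectinload(UserFile.assistants).selectinload(Persona.groups))
    stmt = select(UserFile).where(UserFile.id.in_(user_file_ids)).options(*options)
    return list(db_session.execute(stmt).scalars())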

@@ -1,4 +1,5 @@
import json
import os
import random
import secrets
import string
@@ -120,7 +121,6 @@ from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
@@ -146,10 +146,22 @@ def is_user_admin(user: User) -> bool:


def verify_auth_setting() -> None:
    if AUTH_TYPE == AuthType.CLOUD:
    """Log warnings for AUTH_TYPE issues.

    This only runs on app startup, not during migrations/scripts.
    """
    raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()

    if raw_auth_type == "cloud":
        raise ValueError(
            f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
            "'cloud' is not a valid auth type for self-hosted deployments."
        )
    if raw_auth_type == "disabled":
        logger.warning(
            "AUTH_TYPE='disabled' is no longer supported. "
            "Using 'basic' instead. Please update your configuration."
        )

    logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")


@@ -201,13 +213,14 @@ def user_needs_to_be_verified() -> bool:


def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
    redis_client = get_redis_client(tenant_id=tenant_id)
    value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
    from onyx.cache.factory import get_cache_backend

    cache = get_cache_backend(tenant_id=tenant_id)
    value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)

    if value is None:
        return False

    assert isinstance(value, bytes)
    return int(value.decode("utf-8")) == 1


@@ -1,142 +0,0 @@
from typing import Any
from typing import cast

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    EXTRA_CONCURRENCY = 8  # small extra fudge factor for connection limits

    logger.info("worker_init signal received for consolidated background worker.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)

    # Initialize Vespa httpx pool (needed for light worker tasks)
    if MANAGED_VESPA:
        httpx_init_vespa_pool(
            sender.concurrency + EXTRA_CONCURRENCY,  # type: ignore
            ssl_cert=VESPA_CLOUD_CERT_PATH,
            ssl_key=VESPA_CLOUD_KEY_PATH,
        )
    else:
        httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY)  # type: ignore

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Fewer startup checks in the multi-tenant case
    if MULTI_TENANT:
        return

    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:  # noqa: ARG001
    SqlEngine.reset_engine()


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    app_base.filter_task_modules(
        [
            # Original background worker tasks
            "onyx.background.celery.tasks.pruning",
            "onyx.background.celery.tasks.monitoring",
            "onyx.background.celery.tasks.user_file_processing",
            "onyx.background.celery.tasks.llm_model_update",
            # Light worker tasks
            "onyx.background.celery.tasks.shared",
            "onyx.background.celery.tasks.vespa",
            "onyx.background.celery.tasks.connector_deletion",
            "onyx.background.celery.tasks.doc_permission_syncing",
            "onyx.background.celery.tasks.opensearch_migration",
            # Docprocessing worker tasks
            "onyx.background.celery.tasks.docprocessing",
            # Docfetching worker tasks
            "onyx.background.celery.tasks.docfetching",
            # Sandbox cleanup tasks (isolated in build feature)
            "onyx.server.features.build.sandbox.tasks",
        ]
    )
)
@@ -39,9 +39,13 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)


class SlimConnectorExtractionResult(BaseModel):
    """Result of extracting document IDs and hierarchy nodes from a connector."""
    """Result of extracting document IDs and hierarchy nodes from a connector.

    doc_ids: set[str]
    raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
    Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
    """

    raw_id_to_parent: dict[str, str | None]
    hierarchy_nodes: list[HierarchyNode]


@@ -93,30 +97,34 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
    return None


class BatchResult(BaseModel):
    raw_id_to_parent: dict[str, str | None]
    hierarchy_nodes: list[HierarchyNode]


def _extract_from_batch(
    doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> tuple[set[str], list[HierarchyNode]]:
    """Separate a batch into document IDs and hierarchy nodes.
) -> BatchResult:
    """Separate a batch into document IDs (with parent mapping) and hierarchy nodes.

    ConnectorFailure items have their failed document/entity IDs added to the
    ID set so that failed-to-retrieve documents are not accidentally pruned.
    ID dict so that failed-to-retrieve documents are not accidentally pruned.
    """
    ids: set[str] = set()
    ids: dict[str, str | None] = {}
    hierarchy_nodes: list[HierarchyNode] = []
    for item in doc_list:
        if isinstance(item, HierarchyNode):
            hierarchy_nodes.append(item)
            ids.add(item.raw_node_id)
        elif isinstance(item, ConnectorFailure):
            failed_id = _get_failure_id(item)
            if failed_id:
                ids.add(failed_id)
                ids[failed_id] = None
                logger.warning(
                    f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
                )
        else:
            ids.add(item.id)
    return ids, hierarchy_nodes
            ids[item.id] = item.parent_hierarchy_raw_node_id
    return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
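The dict-based extraction above can be exercised with toy stand-ins; the real Document/SlimDocument/ConnectorFailure models live in onyx, so the class shapes here are assumptions:

from dataclasses import dataclass


@dataclass
class Doc:  # toy stand-in for onyx's Document/SlimDocument
    id: str
    parent_hierarchy_raw_node_id: str | None = None


@dataclass
class Failure:  # toy stand-in for ConnectorFailure
    failed_id: str


raw_id_to_parent: dict[str, str | None] = {}
for item in [Doc("d1", "folder-1"), Doc("d2"), Failure("d3")]:
    if isinstance(item, Failure):
        # failed fetches keep a None parent so they are never pruned by mistake
        raw_id_to_parent[item.failed_id] = None
    else:
        raw_id_to_parent[item.id] = item.parent_hierarchy_raw_node_id

assert raw_id_to_parent == {"d1": "folder-1", "d2": None, "d3": None}
assert set(raw_id_to_parent.keys()) == {"d1", "d2", "d3"}  # the old set-of-IDs view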


def extract_ids_from_runnable_connector(
@@ -132,7 +140,7 @@ def extract_ids_from_runnable_connector(

    Optionally, a callback can be passed to handle the length of each document batch.
    """
    all_connector_doc_ids: set[str] = set()
    all_raw_id_to_parent: dict[str, str | None] = {}
    all_hierarchy_nodes: list[HierarchyNode] = []

    # Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -177,15 +185,18 @@
                "extract_ids_from_runnable_connector: Stop signal detected"
            )

        batch_ids, batch_nodes = _extract_from_batch(doc_list)
        all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
        batch_result = _extract_from_batch(doc_list)
        batch_ids = batch_result.raw_id_to_parent
        batch_nodes = batch_result.hierarchy_nodes
        doc_batch_processing_func(batch_ids)
        all_raw_id_to_parent.update(batch_ids)
        all_hierarchy_nodes.extend(batch_nodes)

        if callback:
            callback.progress("extract_ids_from_runnable_connector", len(batch_ids))

    return SlimConnectorExtractionResult(
        doc_ids=all_connector_doc_ids,
        raw_id_to_parent=all_raw_id_to_parent,
        hierarchy_nodes=all_hierarchy_nodes,
    )


@@ -1,23 +0,0 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4
@@ -40,6 +40,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
@@ -289,6 +290,14 @@ def _run_hierarchy_extraction(
            is_connector_public=is_connector_public,
        )

        upsert_hierarchy_node_cc_pair_entries(
            db_session=db_session,
            hierarchy_node_ids=[n.id for n in upserted_nodes],
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
            commit=True,
        )

        # Cache in Redis for fast ancestor resolution
        cache_entries = [
            HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes

@@ -30,6 +30,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
    transform_vespa_chunks_to_opensearch_chunks,
)
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -47,6 +48,7 @@ from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
    OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
@@ -146,7 +148,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
        task_logger.error(err_str)
        return False

    with get_session_with_current_tenant() as db_session:
    with (
        get_session_with_current_tenant() as db_session,
        get_vespa_http_client(
            timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
        ) as vespa_client,
    ):
        try_insert_opensearch_tenant_migration_record_with_commit(db_session)
        search_settings = get_current_search_settings(db_session)
        tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
@@ -161,6 +168,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
            index_name=search_settings.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=False,
            httpx_client=vespa_client,
        )

        sanitized_doc_start_time = time.monotonic()

@@ -29,6 +29,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -47,8 +48,15 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -57,6 +65,9 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -113,6 +124,38 @@ class PruneCallback(IndexingCallbackBase):
        super().progress(tag, amount)


def _resolve_and_update_document_parents(
    db_session: Session,
    redis_client: Redis,
    source: DocumentSource,
    raw_id_to_parent: dict[str, str | None],
) -> None:
    """Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
    each document and bulk-update the DB. Mirrors the resolution logic in
    run_docfetching.py."""
    source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)

    resolved: dict[str, int | None] = {}
    for doc_id, raw_parent_id in raw_id_to_parent.items():
        if raw_parent_id is None:
            continue
        node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
        resolved[doc_id] = node_id if found else source_node_id

    if not resolved:
        return

    update_document_parent_hierarchy_nodes(
        db_session=db_session,
        doc_parent_map=resolved,
        commit=True,
    )
    task_logger.info(
        f"Pruning: resolved and updated parent hierarchy for "
        f"{len(resolved)} documents (source={source.value})"
    )
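A compact model of that resolution pass, with the Redis node cache replaced by a plain dict and the node IDs assumed for illustration:

# Toy resolution pass: raw parent IDs -> numeric DB node IDs, with the
# source-level root node as the fallback for unknown parents.
node_cache: dict[str, int] = {"folder-1": 101, "folder-2": 102}
SOURCE_NODE_ID = 1  # assumed root node for the source

raw_id_to_parent = {"d1": "folder-1", "d2": None, "d3": "folder-9"}

resolved: dict[str, int | None] = {}
for doc_id, raw_parent in raw_id_to_parent.items():
    if raw_parent is None:
        continue  # document is parentless; nothing to update
    resolved[doc_id] = node_cache.get(raw_parent, SOURCE_NODE_ID)

assert resolved == {"d1": 101, "d3": 1}  # d3's unknown parent falls back to the root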


"""Jobs / utils for kicking off pruning tasks."""


@@ -535,33 +578,42 @@ def connector_pruning_generator_task(
            extraction_result = extract_ids_from_runnable_connector(
                runnable_connector, callback
            )
            all_connector_doc_ids = extraction_result.doc_ids
            all_connector_doc_ids = extraction_result.raw_id_to_parent

            # Process hierarchy nodes (same as docfetching):
            # upsert to Postgres and cache in Redis
            source = cc_pair.connector.source
            redis_client = get_redis_client(tenant_id=tenant_id)

            ensure_source_node_exists(redis_client, db_session, source)

            upserted_nodes: list[DBHierarchyNode] = []
            if extraction_result.hierarchy_nodes:
                is_connector_public = cc_pair.access_type == AccessType.PUBLIC

                redis_client = get_redis_client(tenant_id=tenant_id)
                ensure_source_node_exists(
                    redis_client, db_session, cc_pair.connector.source
                )

                upserted_nodes = upsert_hierarchy_nodes_batch(
                    db_session=db_session,
                    nodes=extraction_result.hierarchy_nodes,
                    source=cc_pair.connector.source,
                    source=source,
                    commit=True,
                    is_connector_public=is_connector_public,
                )

                upsert_hierarchy_node_cc_pair_entries(
                    db_session=db_session,
                    hierarchy_node_ids=[n.id for n in upserted_nodes],
                    connector_id=connector_id,
                    credential_id=credential_id,
                    commit=True,
                )

                cache_entries = [
                    HierarchyNodeCacheEntry.from_db_model(node)
                    for node in upserted_nodes
                ]
                cache_hierarchy_nodes_batch(
                    redis_client=redis_client,
                    source=cc_pair.connector.source,
                    source=source,
                    entries=cache_entries,
                )

@@ -570,6 +622,25 @@ def connector_pruning_generator_task(
                    f"hierarchy nodes for cc_pair={cc_pair_id}"
                )

            # Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
            # and bulk-update documents, mirroring the docfetching resolution
            _resolve_and_update_document_parents(
                db_session=db_session,
                redis_client=redis_client,
                source=source,
                raw_id_to_parent=all_connector_doc_ids,
            )

            # Link hierarchy nodes to documents for sources where pages can be
            # both hierarchy nodes AND documents (e.g. Notion, Confluence)
            all_doc_id_list = list(all_connector_doc_ids.keys())
            link_hierarchy_nodes_to_documents(
                db_session=db_session,
                document_ids=all_doc_id_list,
                source=source,
                commit=True,
            )

            # a list of docs in our local index
            all_indexed_document_ids = {
                doc.id
@@ -581,7 +652,9 @@ def connector_pruning_generator_task(
            }

            # generate list of docs to remove (no longer in the source)
            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
            doc_ids_to_remove = list(
                all_indexed_document_ids - all_connector_doc_ids.keys()
            )

            task_logger.info(
                "Pruning set collected: "
@@ -605,6 +678,43 @@ def connector_pruning_generator_task(
            )

            redis_connector.prune.generator_complete = tasks_generated

            # --- Hierarchy node pruning ---
            live_node_ids = {n.id for n in upserted_nodes}
            stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
                db_session=db_session,
                connector_id=connector_id,
                credential_id=credential_id,
                live_hierarchy_node_ids=live_node_ids,
                commit=True,
            )
            deleted_raw_ids = delete_orphaned_hierarchy_nodes(
                db_session=db_session,
                source=source,
                commit=True,
            )
            reparented_nodes = reparent_orphaned_hierarchy_nodes(
                db_session=db_session,
                source=source,
                commit=True,
            )
            if deleted_raw_ids:
                evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
            if reparented_nodes:
                reparented_cache_entries = [
                    HierarchyNodeCacheEntry.from_db_model(node)
                    for node in reparented_nodes
                ]
                cache_hierarchy_nodes_batch(
                    redis_client, source, reparented_cache_entries
                )
            if stale_removed or deleted_raw_ids or reparented_nodes:
                task_logger.info(
                    f"Hierarchy node pruning: cc_pair={cc_pair_id} "
                    f"stale_entries_removed={stale_removed} "
                    f"nodes_deleted={len(deleted_raw_ids)} "
                    f"nodes_reparented={len(reparented_nodes)}"
                )
    except Exception as e:
        task_logger.exception(
            f"Pruning exceptioned: cc_pair={cc_pair_id} "

@@ -12,9 +12,9 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
@@ -43,7 +43,9 @@ from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
@@ -54,6 +56,7 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.variable_functionality import global_version


def _as_uuid(value: str | UUID) -> UUID:
@@ -791,11 +794,12 @@ def project_sync_user_file_impl(

    try:
        with get_session_with_current_tenant() as db_session:
            user_file = db_session.execute(
                select(UserFile)
                .where(UserFile.id == _as_uuid(user_file_id))
                .options(selectinload(UserFile.assistants))
            ).scalar_one_or_none()
            user_files = fetch_user_files_with_access_relationships(
                [user_file_id],
                db_session,
                eager_load_groups=global_version.is_ee_version(),
            )
            user_file = user_files[0] if user_files else None
            if not user_file:
                task_logger.info(
                    f"project_sync_user_file_impl - User file not found id={user_file_id}"
@@ -823,12 +827,21 @@ def project_sync_user_file_impl(

            project_ids = [project.id for project in user_file.projects]
            persona_ids = [p.id for p in user_file.assistants if not p.deleted]

            file_id_str = str(user_file.id)
            access_map = build_access_for_user_files([user_file])
            access = access_map.get(file_id_str)

            for retry_document_index in retry_document_indices:
                retry_document_index.update_single(
                    doc_id=str(user_file.id),
                    doc_id=file_id_str,
                    tenant_id=tenant_id,
                    chunk_count=user_file.chunk_count,
                    fields=None,
                    fields=(
                        VespaDocumentFields(access=access)
                        if access is not None
                        else None
                    ),
                    user_fields=VespaDocumentUserFields(
                        user_projects=project_ids,
                        personas=persona_ids,

@@ -1,10 +0,0 @@
from celery import Celery

from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
    "onyx.background.celery.apps.background",
    "celery_app",
)
@@ -45,6 +45,7 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import ProcessingMode
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
@@ -58,8 +59,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -71,6 +70,8 @@ from onyx.server.features.build.indexing.persistent_document_writer import (
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
@@ -587,6 +588,14 @@ def connector_document_extraction(
                is_connector_public=is_connector_public,
            )

            upsert_hierarchy_node_cc_pair_entries(
                db_session=db_session,
                hierarchy_node_ids=[n.id for n in upserted_nodes],
                connector_id=db_connector.id,
                credential_id=db_credential.id,
                commit=True,
            )

            # Cache in Redis for fast ancestor resolution during doc processing
            redis_client = get_redis_client(tenant_id=tenant_id)
            cache_entries = [

backend/onyx/cache/interface.py
@@ -1,9 +1,20 @@
import abc
from enum import Enum

from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError

TTL_KEY_NOT_FOUND = -2
TTL_NO_EXPIRY = -1

CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
"""Exception types that represent transient cache connectivity / operational
failures. Callers that want to fail-open (or fail-closed) on cache errors
should catch this tuple instead of bare ``Exception``.

When adding a new ``CacheBackend`` implementation, add its transient error
base class(es) here so all call-sites pick it up automatically."""
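A minimal sketch of the fail-open pattern this tuple is designed for; the helper below is hypothetical and not part of the diff:

from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError

CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)


def cached_lookup_or_default(cache, key: str, default: bytes | None = None):
    """Hypothetical fail-open read: a cache outage degrades to the default
    instead of taking the request down. Programming errors still propagate
    because only the transient error tuple is caught."""
    try:
        return cache.get(key)
    except CACHE_TRANSIENT_ERRORS:
        return default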


class CacheBackendType(str, Enum):
    REDIS = "redis"

@@ -36,7 +36,6 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -51,7 +50,9 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import CustomToolCallSummary
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import PythonToolRichResponse
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -83,28 +84,6 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
    )


def _should_keep_bedrock_tool_definitions(
    llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
    """Bedrock requires tool config when history includes toolUse/toolResult blocks."""
    model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
    if model_provider not in {
        LlmProviderNames.BEDROCK,
        LlmProviderNames.BEDROCK_CONVERSE,
    }:
        return False

    return any(
        (
            msg.message_type == MessageType.ASSISTANT
            and msg.tool_calls
            and len(msg.tool_calls) > 0
        )
        or msg.message_type == MessageType.TOOL_CALL_RESPONSE
        for msg in simple_chat_history
    )


def _try_fallback_tool_extraction(
    llm_step_result: LlmStepResult,
    tool_choice: ToolChoiceOptions,
@@ -685,12 +664,7 @@ def run_llm_loop(
    elif out_of_cycles or ran_image_gen:
        # Last cycle, no tools allowed, just answer!
        tool_choice = ToolChoiceOptions.NONE
        # Bedrock requires tool config in requests that include toolUse/toolResult history.
        final_tools = (
            tools
            if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
            else []
        )
        final_tools = []
    else:
        tool_choice = ToolChoiceOptions.AUTO
        final_tools = tools
@@ -966,6 +940,13 @@ def run_llm_loop(
    ):
        generated_images = tool_response.rich_response.generated_images

    # Extract generated_files if this is a code interpreter response
    generated_files = None
    if isinstance(tool_response.rich_response, PythonToolRichResponse):
        generated_files = (
            tool_response.rich_response.generated_files or None
        )

    # Persist memory if this is a memory tool response
    memory_snapshot: MemoryToolResponseSnapshot | None = None
    if isinstance(tool_response.rich_response, MemoryToolResponse):
@@ -1000,6 +981,10 @@

    if memory_snapshot:
        saved_response = json.dumps(memory_snapshot.model_dump())
    elif isinstance(tool_response.rich_response, CustomToolCallSummary):
        saved_response = json.dumps(
            tool_response.rich_response.model_dump()
        )
    elif isinstance(tool_response.rich_response, str):
        saved_response = tool_response.rich_response
    else:
@@ -1017,6 +1002,7 @@
        tool_call_response=saved_response,
        search_docs=displayed_docs or search_docs,
        generated_images=generated_images,
        generated_files=generated_files,
    )
    # Add to state container for partial save support
    state_container.add_tool_call(tool_call_info)

@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
@@ -54,7 +55,9 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.jsonriver import Parser
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
from onyx.utils.text_processing import find_all_json_objects

logger = setup_logger()
@@ -166,15 +169,6 @@ def _find_function_calls_open_marker(text_lower: str) -> int:
        search_from = idx + 1


def _sanitize_llm_output(value: str) -> str:
    """Remove characters that PostgreSQL's text/JSONB types cannot store.

    - NULL bytes (\x00): Not allowed in PostgreSQL text types
    - UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding
    """
    return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff"))


def _try_parse_json_string(value: Any) -> Any:
    """Attempt to parse a JSON string value into its Python equivalent.

@@ -222,9 +216,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
    if isinstance(raw_args, dict):
        # Parse any string values that look like JSON arrays/objects
        return {
            k: _try_parse_json_string(
                _sanitize_llm_output(v) if isinstance(v, str) else v
            )
            k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v)
            for k, v in raw_args.items()
        }

@@ -232,7 +224,7 @@
        return {}

    # Sanitize before parsing to remove NULL bytes and surrogates
    raw_args = _sanitize_llm_output(raw_args)
    raw_args = sanitize_string(raw_args)

    try:
        parsed1: Any = json.loads(raw_args)
@@ -545,12 +537,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
    )
    if not attr_match:
        return None
    return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
    return sanitize_string(unescape(attr_match.group(2).strip()))


def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
    """Parse a parameter value from XML-style tool call payloads."""
    value = _sanitize_llm_output(unescape(raw_value).strip())
    value = sanitize_string(unescape(raw_value).strip())

    if string_attr and string_attr.lower() == "true":
        return value
@@ -569,6 +561,7 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
    """
    arguments = obj.get("arguments", obj.get("parameters", {}))
    if isinstance(arguments, str):
        arguments = sanitize_string(arguments)
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
@@ -1018,6 +1011,7 @@ def run_llm_step_pkt_generator(
    )

    id_to_tool_call_map: dict[int, dict[str, Any]] = {}
    arg_parsers: dict[int, Parser] = {}
    reasoning_start = False
    answer_start = False
    accumulated_reasoning = ""
@@ -1224,7 +1218,14 @@
            yield from _close_reasoning_if_active()

            for tool_call_delta in delta.tool_calls:
                # maybe_emit_argument_delta depends on _update_tool_call_with_delta
                # being called first and attaching the delta
                _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
                yield from maybe_emit_argument_delta(
                    tool_calls_in_progress=id_to_tool_call_map,
                    tool_call_delta=tool_call_delta,
                    placement=_current_placement(),
                    parsers=arg_parsers,
                )

    # Flush any tail text buffered while checking for split "<function_calls" markers.
    filtered_content_tail = xml_tool_call_content_filter.flush()

@@ -1,4 +1,5 @@
import json
import mimetypes

from sqlalchemy.orm import Session

@@ -12,14 +13,42 @@ from onyx.db.chat import create_db_search_doc
from onyx.db.models import ChatMessage
from onyx.db.models import ToolCall
from onyx.db.tools import create_tool_call_no_commit
from onyx.file_store.models import FileDescriptor
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.tools.models import ToolCallInfo
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string

logger = setup_logger()


def _extract_referenced_file_descriptors(
    tool_calls: list[ToolCallInfo],
    message_text: str,
) -> list[FileDescriptor]:
    """Extract FileDescriptors for code interpreter files referenced in the message text."""
    descriptors: list[FileDescriptor] = []
    for tool_call_info in tool_calls:
        if not tool_call_info.generated_files:
            continue
        for gen_file in tool_call_info.generated_files:
            file_id = (
                gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else ""
            )
            if file_id and file_id in message_text:
                mime_type, _ = mimetypes.guess_type(gen_file.filename)
                descriptors.append(
                    FileDescriptor(
                        id=file_id,
                        type=mime_type_to_chat_file_type(mime_type),
                        name=gen_file.filename,
                    )
                )
    return descriptors
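The matching rule is simple: the trailing path segment of file_link is the file ID, and a file is attached only if that ID literally appears in the assistant's text. A toy check (the URL and message are hypothetical):

import mimetypes

# file_link's last path segment doubles as the file ID the assistant cites.
file_link = "https://files.example.com/generated/abc123"  # hypothetical link
file_id = file_link.rsplit("/", 1)[-1]
assert file_id == "abc123"

message_text = "I plotted the results; see abc123 for the chart."
assert file_id in message_text  # only then does the file get attached

mime_type, _ = mimetypes.guess_type("chart.png")
assert mime_type == "image/png"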


def _create_and_link_tool_calls(
    tool_calls: list[ToolCallInfo],
    assistant_message: ChatMessage,
@@ -173,8 +202,13 @@ def save_chat_turn(
        pre_answer_processing_time: Duration of processing before answer starts (in seconds)
    """
    # 1. Update ChatMessage with message content, reasoning tokens, and token count
    assistant_message.message = message_text
    assistant_message.reasoning_tokens = reasoning_tokens
    sanitized_message_text = (
        sanitize_string(message_text) if message_text else message_text
    )
    assistant_message.message = sanitized_message_text
    assistant_message.reasoning_tokens = (
        sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
    )
    assistant_message.is_clarification = is_clarification

    # Use pre-answer processing time (captured when MESSAGE_START was emitted)
@@ -184,8 +218,10 @@
    # Calculate token count using default tokenizer, when storing, this should not use the LLM
    # specific one so we use a system default tokenizer here.
    default_tokenizer = get_tokenizer(None, None)
    if message_text:
        assistant_message.token_count = len(default_tokenizer.encode(message_text))
    if sanitized_message_text:
        assistant_message.token_count = len(
            default_tokenizer.encode(sanitized_message_text)
        )
    else:
        assistant_message.token_count = 0

@@ -297,5 +333,16 @@
        citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
    )

    # 8. Attach code interpreter generated files that the assistant actually
    # referenced in its response, so they are available via load_all_chat_files
    # on subsequent turns. Files not mentioned are intermediate artifacts.
    if sanitized_message_text:
        referenced = _extract_referenced_file_descriptors(
            tool_calls, sanitized_message_text
        )
        if referenced:
            existing_files = assistant_message.files or []
            assistant_message.files = existing_files + referenced

    # Finally save the messages, tool calls, and docs
    db_session.commit()

backend/onyx/chat/tool_call_args_streaming.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from collections.abc import Generator
from collections.abc import Mapping
from typing import Any
from typing import Type

from onyx.llm.model_response import ChatCompletionDeltaToolCall
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
from onyx.tools.interface import Tool
from onyx.utils.jsonriver import Parser


def _get_tool_class(
    tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
    tool_call_delta: ChatCompletionDeltaToolCall,
) -> Type[Tool] | None:
    """Look up the Tool subclass for a streaming tool call delta."""
    tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
    if not tool_name:
        return None
    return TOOL_NAME_TO_CLASS.get(tool_name)


def maybe_emit_argument_delta(
    tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
    tool_call_delta: ChatCompletionDeltaToolCall,
    placement: Placement,
    parsers: dict[int, Parser],
) -> Generator[Packet, None, None]:
    """Emit decoded tool-call argument deltas to the frontend.

    Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
    the JSON argument string and extract only the newly-appended content
    for each string-valued argument.

    NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
    are skipped — they are available in the final tool-call kickoff packet.

    ``parsers`` is a mutable dict keyed by tool-call index. A new
    ``Parser`` is created automatically for each new index.
    """
    tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
    if not tool_cls or not tool_cls.should_emit_argument_deltas():
        return

    fn = tool_call_delta.function
    delta_fragment = fn.arguments if fn else None
    if not delta_fragment:
        return

    idx = tool_call_delta.index
    if idx not in parsers:
        parsers[idx] = Parser()
    parser = parsers[idx]

    deltas = parser.feed(delta_fragment)

    argument_deltas: dict[str, str] = {}
    for delta in deltas:
        if isinstance(delta, dict):
            for key, value in delta.items():
                if isinstance(value, str):
                    argument_deltas[key] = argument_deltas.get(key, "") + value

    if not argument_deltas:
        return

    tc_data = tool_calls_in_progress[tool_call_delta.index]
    yield Packet(
        placement=placement,
        obj=ToolCallArgumentDelta(
            tool_type=tc_data.get("name", ""),
            argument_deltas=argument_deltas,
        ),
    )
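The incremental-argument idea can be demonstrated without jsonriver: parse the growing buffer after each fragment and emit only the newly appended suffix of each string value. The '"}' patch below is a toy trick that only suits this single-string-key shape; the real Parser handles arbitrary partial JSON:

import json

fragments = ['{"query": "we', 'ather in Par', 'is"}']

buffer = ""
seen = ""
deltas: list[str] = []
for fragment in fragments:
    buffer += fragment
    try:
        parsed = json.loads(buffer)
    except json.JSONDecodeError:
        try:
            parsed = json.loads(buffer + '"}')  # close the dangling string
        except json.JSONDecodeError:
            continue
    value = parsed.get("query", "")
    deltas.append(value[len(seen):])  # only the newly streamed suffix
    seen = value

assert deltas == ["we", "ather in Par", "is"]
assert "".join(deltas) == "weather in Paris"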
@@ -68,6 +68,10 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
    os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)

# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024

# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"

@@ -92,19 +96,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
#####
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
    logger.warning(
        "AUTH_TYPE='disabled' is no longer supported. "
        "Defaulting to 'basic'. Please update your configuration. "
        "Your existing data will be migrated automatically."
    )
    _auth_type_str = AuthType.BASIC.value
try:
# Silently default to basic - warnings/errors logged in verify_auth_setting()
# which only runs on app startup, not during migrations/scripts
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
    AUTH_TYPE = AuthType(_auth_type_str)
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
    AUTH_TYPE = AuthType(_auth_type_str)
except ValueError:
    logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
else:
    AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
@@ -288,8 +285,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
@@ -495,14 +493,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
|
||||
# Individual worker concurrency settings
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
@@ -819,7 +810,9 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
@@ -900,6 +893,9 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
|
||||
@@ -84,7 +84,6 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
@@ -943,6 +943,9 @@ class ConfluenceConnector(
                        if include_permissions
                        else None
                    ),
                    parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
                        page
                    ),
                )
            )

@@ -992,6 +995,7 @@ class ConfluenceConnector(
                        if include_permissions
                        else None
                    ),
                    parent_hierarchy_raw_node_id=page_id,
                )
            )
@@ -1,4 +1,5 @@
import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime

@@ -204,7 +205,7 @@ def _manage_async_retrieval(

    end_time: datetime | None = end

    async def _async_fetch() -> AsyncIterable[Document]:
    async def _async_fetch() -> AsyncGenerator[Document, None]:
        intents = Intents.default()
        intents.message_content = True
        async with Client(intents=intents) as discord_client:

@@ -227,22 +228,23 @@ def _manage_async_retrieval(

    def run_and_yield() -> Iterable[Document]:
        loop = asyncio.new_event_loop()
        async_gen = _async_fetch()
        try:
            # Get the async generator
            async_gen = _async_fetch()
            # Convert to AsyncIterator
            async_iter = async_gen.__aiter__()
            while True:
                try:
                    # Create a coroutine by calling anext with the async iterator
                    next_coro = anext(async_iter)
                    # Run the coroutine to get the next document
                    doc = loop.run_until_complete(next_coro)
                    doc = loop.run_until_complete(anext(async_gen))
                    yield doc
                except StopAsyncIteration:
                    break
        finally:
            loop.close()
            # Must close the async generator before the loop so the Discord
            # client's `async with` block can await its shutdown coroutine.
            # The nested try/finally ensures the loop always closes even if
            # aclose() raises (same pattern as cursor.close() before conn.close()).
            try:
                loop.run_until_complete(async_gen.aclose())
            finally:
                loop.close()

    return run_and_yield()
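The pattern above, draining an async generator from synchronous code and closing the generator before the event loop, generalizes beyond Discord. A self-contained sketch using only the standard library (no Onyx or Discord types assumed):

```python
import asyncio
from collections.abc import AsyncGenerator, Iterable


async def _aiter_numbers() -> AsyncGenerator[int, None]:
    # Stand-in for an async source (e.g. an API client in an `async with` block).
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def run_and_yield_sync() -> Iterable[int]:
    loop = asyncio.new_event_loop()
    async_gen = _aiter_numbers()
    try:
        while True:
            try:
                # anext() returns a coroutine; the loop drives it to completion.
                yield loop.run_until_complete(anext(async_gen))
            except StopAsyncIteration:
                break
    finally:
        # Close the generator first so its cleanup code can still await,
        # then close the loop even if aclose() raises.
        try:
            loop.run_until_complete(async_gen.aclose())
        finally:
            loop.close()


print(list(run_and_yield_sync()))  # [0, 1, 2]
```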
@@ -1722,6 +1722,7 @@ class GoogleDriveConnector(
                    primary_admin_email=self.primary_admin_email,
                    google_domain=self.google_domain,
                ),
                retriever_email=file.user_email,
            ):
                slim_batch.append(doc)
@@ -476,6 +476,7 @@ def _get_external_access_for_raw_gdrive_file(
    company_domain: str,
    retriever_drive_service: GoogleDriveService | None,
    admin_drive_service: GoogleDriveService,
    fallback_user_email: str,
    add_prefix: bool = False,
) -> ExternalAccess:
    """

@@ -484,6 +485,8 @@ def _get_external_access_for_raw_gdrive_file(
    add_prefix: When True, prefix group IDs with source type (for indexing path).
        When False (default), leave unprefixed (for permission sync path
        where upsert_document_external_perms handles prefixing).
    fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
        files), fall back to granting access to this user.
    """
    external_access_fn = cast(
        Callable[

@@ -492,6 +495,7 @@ def _get_external_access_for_raw_gdrive_file(
            str,
            GoogleDriveService | None,
            GoogleDriveService,
            str,
            bool,
        ],
        ExternalAccess,

@@ -507,6 +511,7 @@ def _get_external_access_for_raw_gdrive_file(
        company_domain,
        retriever_drive_service,
        admin_drive_service,
        fallback_user_email,
        add_prefix,
    )

@@ -672,6 +677,7 @@ def _convert_drive_item_to_document(
                creds, user_email=permission_sync_context.primary_admin_email
            ),
            add_prefix=True,  # Indexing path - prefix here
            fallback_user_email=retriever_email,
        )
        if permission_sync_context
        else None

@@ -753,6 +759,7 @@ def build_slim_document(
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
    retriever_email: str,
) -> SlimDocument | None:
    if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
        return None

@@ -774,6 +781,7 @@ def build_slim_document(
                creds,
                user_email=permission_sync_context.primary_admin_email,
            ),
            fallback_user_email=retriever_email,
        )
        if permission_sync_context
        else None

@@ -781,4 +789,5 @@ def build_slim_document(
    return SlimDocument(
        id=onyx_document_id_from_drive_file(file),
        external_access=external_access,
        parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
    )
@@ -44,6 +44,7 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.db.credentials import update_credential_json
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import unwrap_str
from onyx.server.documents.models import CredentialBase
from onyx.server.documents.models import GoogleAppCredentials
from onyx.server.documents.models import GoogleServiceAccountKey

@@ -89,7 +90,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->


def verify_csrf(credential_id: int, state: str) -> None:
    csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
    csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id))))
    if csrf != state:
        raise PermissionError(
            "State from Google Drive Connector callback does not match expected"

@@ -178,7 +179,9 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
    params = parse_qs(parsed_url.query)

    get_kv_store().store(
        KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
        KV_CRED_KEY.format(credential_id),
        {"value": params.get("state", [None])[0]},
        encrypt=True,
    )
    return str(auth_url)
@@ -902,6 +902,11 @@ class JiraConnector(
                    external_access=self._get_project_permissions(
                        project_key, add_prefix=False
                    ),
                    parent_hierarchy_raw_node_id=(
                        self._get_parent_hierarchy_raw_node_id(issue, project_key)
                        if project_key
                        else None
                    ),
                )
            )
            current_offset += 1
@@ -385,6 +385,7 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
    id: str
    external_access: ExternalAccess | None = None
    parent_hierarchy_raw_node_id: str | None = None


class HierarchyNode(BaseModel):
@@ -772,6 +772,7 @@ def _convert_driveitem_to_slim_document(
    drive_name: str,
    ctx: ClientContext,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
    if driveitem.id is None:
        raise ValueError("DriveItem ID is required")

@@ -787,11 +788,15 @@ def _convert_driveitem_to_slim_document(
    return SlimDocument(
        id=driveitem.id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


def _convert_sitepage_to_slim_document(
    site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
    site_page: dict[str, Any],
    ctx: ClientContext | None,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
    """Convert a SharePoint site page to a SlimDocument object."""
    if site_page.get("id") is None:

@@ -808,6 +813,7 @@ def _convert_sitepage_to_slim_document(
    return SlimDocument(
        id=id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


@@ -1594,12 +1600,22 @@ class SharepointConnector(
                )
            )

            parent_hierarchy_url: str | None = None
            if drive_web_url:
                parent_hierarchy_url = self._get_parent_hierarchy_url(
                    site_url, drive_web_url, drive_name, driveitem
                )

            try:
                logger.debug(f"Processing: {driveitem.web_url}")
                ctx = self._create_rest_client_context(site_descriptor.url)
                doc_batch.append(
                    _convert_driveitem_to_slim_document(
                        driveitem, drive_name, ctx, self.graph_client
                        driveitem,
                        drive_name,
                        ctx,
                        self.graph_client,
                        parent_hierarchy_raw_node_id=parent_hierarchy_url,
                    )
                )
            except Exception as e:

@@ -1619,7 +1635,10 @@ class SharepointConnector(
            ctx = self._create_rest_client_context(site_descriptor.url)
            doc_batch.append(
                _convert_sitepage_to_slim_document(
                    site_page, ctx, self.graph_client
                    site_page,
                    ctx,
                    self.graph_client,
                    parent_hierarchy_raw_node_id=site_descriptor.url,
                )
            )
            if len(doc_batch) >= SLIM_BATCH_SIZE:
@@ -565,6 +565,7 @@ def _get_all_doc_ids(
                    channel_id=channel_id, thread_ts=message["ts"]
                ),
                external_access=external_access,
                parent_hierarchy_raw_node_id=channel_id,
            )
        )
@@ -38,6 +38,7 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string


logger = setup_logger()

@@ -675,58 +676,43 @@ def set_as_latest_chat_message(
    db_session.commit()


def _sanitize_for_postgres(value: str) -> str:
    """Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
    sanitized = value.replace("\x00", "")
    if value and not sanitized:
        logger.warning("Sanitization removed all characters from string")
    return sanitized


def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
    """Remove NUL (0x00) characters from all strings in a list."""
    return [_sanitize_for_postgres(v) for v in values]


def create_db_search_doc(
    server_search_doc: ServerSearchDoc,
    db_session: Session,
    commit: bool = True,
) -> DBSearchDoc:
    # Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
    db_search_doc = DBSearchDoc(
        document_id=_sanitize_for_postgres(server_search_doc.document_id),
        document_id=sanitize_string(server_search_doc.document_id),
        chunk_ind=server_search_doc.chunk_ind,
        semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
        semantic_id=sanitize_string(server_search_doc.semantic_identifier),
        link=(
            _sanitize_for_postgres(server_search_doc.link)
            sanitize_string(server_search_doc.link)
            if server_search_doc.link is not None
            else None
        ),
        blurb=_sanitize_for_postgres(server_search_doc.blurb),
        blurb=sanitize_string(server_search_doc.blurb),
        source_type=server_search_doc.source_type,
        boost=server_search_doc.boost,
        hidden=server_search_doc.hidden,
        doc_metadata=server_search_doc.metadata,
        is_relevant=server_search_doc.is_relevant,
        relevance_explanation=(
            _sanitize_for_postgres(server_search_doc.relevance_explanation)
            sanitize_string(server_search_doc.relevance_explanation)
            if server_search_doc.relevance_explanation is not None
            else None
        ),
        # For docs further down that aren't reranked, we can't use the retrieval score
        score=server_search_doc.score or 0.0,
        match_highlights=_sanitize_list_for_postgres(
            server_search_doc.match_highlights
        ),
        match_highlights=[
            sanitize_string(h) for h in server_search_doc.match_highlights
        ],
        updated_at=server_search_doc.updated_at,
        primary_owners=(
            _sanitize_list_for_postgres(server_search_doc.primary_owners)
            [sanitize_string(o) for o in server_search_doc.primary_owners]
            if server_search_doc.primary_owners is not None
            else None
        ),
        secondary_owners=(
            _sanitize_list_for_postgres(server_search_doc.secondary_owners)
            [sanitize_string(o) for o in server_search_doc.secondary_owners]
            if server_search_doc.secondary_owners is not None
            else None
        ),
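The removed private helpers suggest the contract of the shared utility. A sketch, assuming `sanitize_string` mirrors the deleted `_sanitize_for_postgres` (NUL stripping, with a warning when nothing survives); this is an inference, not the actual module:

```python
import logging

logger = logging.getLogger(__name__)


def sanitize_string(value: str) -> str:
    # PostgreSQL text columns cannot store NUL (0x00) bytes, so strip them.
    sanitized = value.replace("\x00", "")
    if value and not sanitized:
        logger.warning("Sanitization removed all characters from string")
    return sanitized


assert sanitize_string("abc\x00def") == "abcdef"
```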
@@ -13,6 +13,7 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType

@@ -246,6 +247,7 @@ def insert_document_set(
        description=document_set_creation_request.description,
        user_id=user_id,
        is_public=document_set_creation_request.is_public,
        is_up_to_date=DISABLE_VECTOR_DB,
        time_last_modified_by_user=func.now(),
    )
    db_session.add(new_document_set_row)

@@ -336,7 +338,8 @@ def update_document_set(
    )

    document_set_row.description = document_set_update_request.description
    document_set_row.is_up_to_date = False
    if not DISABLE_VECTOR_DB:
        document_set_row.is_up_to_date = False
    document_set_row.is_public = document_set_update_request.is_public
    document_set_row.time_last_modified_by_user = func.now()
    versioned_private_doc_set_fn = fetch_versioned_implementation(
@@ -1,6 +1,11 @@
"""CRUD operations for HierarchyNode."""

from collections import defaultdict

from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource

@@ -8,6 +13,7 @@ from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.db.enums import HierarchyNodeType
from onyx.db.models import Document
from onyx.db.models import HierarchyNode
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation

@@ -456,7 +462,7 @@ def get_all_hierarchy_nodes_for_source(
def _get_accessible_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
    user_email: str | None,  # noqa: ARG001
    user_email: str,  # noqa: ARG001
    external_group_ids: list[str],  # noqa: ARG001
) -> list[HierarchyNode]:
    """

@@ -483,7 +489,7 @@ def _get_accessible_hierarchy_nodes_for_source(
def get_accessible_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
    user_email: str | None,
    user_email: str,
    external_group_ids: list[str],
) -> list[HierarchyNode]:
    """

@@ -525,6 +531,53 @@ def get_document_parent_hierarchy_node_ids(
    return {doc_id: parent_id for doc_id, parent_id in results}


def update_document_parent_hierarchy_nodes(
    db_session: Session,
    doc_parent_map: dict[str, int | None],
    commit: bool = True,
) -> int:
    """Bulk-update Document.parent_hierarchy_node_id for multiple documents.

    Only updates rows whose current value differs from the desired value to
    avoid unnecessary writes.

    Args:
        db_session: SQLAlchemy session
        doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
        commit: Whether to commit the transaction

    Returns:
        Number of documents actually updated
    """
    if not doc_parent_map:
        return 0

    doc_ids = list(doc_parent_map.keys())
    existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)

    by_parent: dict[int | None, list[str]] = defaultdict(list)
    for doc_id, desired_parent_id in doc_parent_map.items():
        current = existing.get(doc_id)
        if current == desired_parent_id or doc_id not in existing:
            continue
        by_parent[desired_parent_id].append(doc_id)

    updated = 0
    for desired_parent_id, ids in by_parent.items():
        db_session.query(Document).filter(Document.id.in_(ids)).update(
            {Document.parent_hierarchy_node_id: desired_parent_id},
            synchronize_session=False,
        )
        updated += len(ids)

    if commit:
        db_session.commit()
    elif updated:
        db_session.flush()

    return updated


def update_hierarchy_node_permissions(
    db_session: Session,
    raw_node_id: str,
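A short usage sketch for the bulk updater above; the document and node IDs are invented for illustration:

```python
# Illustrative only: map each document to its (possibly new) parent node.
# Writes are batched per distinct parent, so these three entries cost at
# most two UPDATE statements.
doc_parent_map = {
    "confluence__PAGE-1": 42,    # move under node 42
    "confluence__PAGE-2": 42,    # same parent, shares one UPDATE statement
    "confluence__PAGE-3": None,  # detach from the hierarchy
}
updated = update_document_parent_hierarchy_nodes(
    db_session, doc_parent_map, commit=False
)
# Rows already pointing at the desired parent (or unknown doc IDs) are skipped.
```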
@@ -571,3 +624,154 @@ def update_hierarchy_node_permissions(
    db_session.flush()

    return True


def upsert_hierarchy_node_cc_pair_entries(
    db_session: Session,
    hierarchy_node_ids: list[int],
    connector_id: int,
    credential_id: int,
    commit: bool = True,
) -> None:
    """Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.

    This records that the given cc_pair "owns" these hierarchy nodes. Used by
    indexing, pruning, and hierarchy-fetching paths.
    """
    if not hierarchy_node_ids:
        return

    _M = HierarchyNodeByConnectorCredentialPair
    stmt = pg_insert(_M).values(
        [
            {
                _M.hierarchy_node_id: node_id,
                _M.connector_id: connector_id,
                _M.credential_id: credential_id,
            }
            for node_id in hierarchy_node_ids
        ]
    )
    stmt = stmt.on_conflict_do_nothing()
    db_session.execute(stmt)

    if commit:
        db_session.commit()
    else:
        db_session.flush()


def remove_stale_hierarchy_node_cc_pair_entries(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    live_hierarchy_node_ids: set[int],
    commit: bool = True,
) -> int:
    """Delete join-table rows for this cc_pair that are NOT in the live set.

    If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
    (i.e. the connector no longer has any hierarchy nodes). Callers that want a
    no-op when there are no live nodes must guard before calling.

    Returns the number of deleted rows.
    """
    stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
        HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
        HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
    )
    if live_hierarchy_node_ids:
        stmt = stmt.where(
            HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
                live_hierarchy_node_ids
            )
        )

    result: CursorResult = db_session.execute(stmt)  # type: ignore[assignment]
    deleted = result.rowcount

    if commit:
        db_session.commit()
    elif deleted:
        db_session.flush()

    return deleted


def delete_orphaned_hierarchy_nodes(
    db_session: Session,
    source: DocumentSource,
    commit: bool = True,
) -> list[str]:
    """Delete hierarchy nodes for a source that have zero cc_pair associations.

    SOURCE-type nodes are excluded (they are synthetic roots).

    Returns the list of raw_node_ids that were deleted (for cache eviction).
    """
    # Find orphaned nodes: no rows in the join table
    orphan_stmt = (
        select(HierarchyNode.id, HierarchyNode.raw_node_id)
        .outerjoin(
            HierarchyNodeByConnectorCredentialPair,
            HierarchyNode.id
            == HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
        )
        .where(
            HierarchyNode.source == source,
            HierarchyNode.node_type != HierarchyNodeType.SOURCE,
            HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
        )
    )
    orphans = db_session.execute(orphan_stmt).all()
    if not orphans:
        return []

    orphan_ids = [row[0] for row in orphans]
    deleted_raw_ids = [row[1] for row in orphans]

    db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))

    if commit:
        db_session.commit()
    else:
        db_session.flush()

    return deleted_raw_ids


def reparent_orphaned_hierarchy_nodes(
    db_session: Session,
    source: DocumentSource,
    commit: bool = True,
) -> list[HierarchyNode]:
    """Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.

    After pruning deletes stale nodes, their former children get parent_id=NULL
    via the SET NULL cascade. This function points them back to the SOURCE root.

    Returns the reparented HierarchyNode objects (with updated parent_id)
    so callers can refresh downstream caches.
    """
    source_node = get_source_hierarchy_node(db_session, source)
    if not source_node:
        return []

    stmt = select(HierarchyNode).where(
        HierarchyNode.source == source,
        HierarchyNode.parent_id.is_(None),
        HierarchyNode.node_type != HierarchyNodeType.SOURCE,
    )
    orphans = list(db_session.execute(stmt).scalars().all())
    if not orphans:
        return []

    for node in orphans:
        node.parent_id = source_node.id

    if commit:
        db_session.commit()
    else:
        db_session.flush()

    return orphans
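Taken together, these helpers imply a pruning sequence roughly like the following. This is a sketch of the caller's side with made-up IDs (the real task code is not shown in this diff), and it assumes `DocumentSource` has a `CONFLUENCE` member:

```python
# Hypothetical pruning flow for one connector/credential pair (illustrative).
live_node_ids = {10, 11, 12}  # hierarchy nodes seen in the latest crawl

# 1. Make sure the join table reflects everything the cc_pair currently owns.
upsert_hierarchy_node_cc_pair_entries(
    db_session, sorted(live_node_ids), connector_id=1, credential_id=2, commit=False
)

# 2. Drop join rows for nodes this cc_pair no longer references.
remove_stale_hierarchy_node_cc_pair_entries(
    db_session,
    connector_id=1,
    credential_id=2,
    live_hierarchy_node_ids=live_node_ids,
    commit=False,
)

# 3. Delete nodes no cc_pair references anymore (SOURCE roots are kept) ...
evicted_raw_ids = delete_orphaned_hierarchy_nodes(
    db_session, DocumentSource.CONFLUENCE, commit=False
)

# 4. ... and re-attach any children the deletions left parentless.
reparent_orphaned_hierarchy_nodes(db_session, DocumentSource.CONFLUENCE, commit=True)
```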
@@ -25,8 +25,11 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()


def update_group_llm_provider_relationships__no_commit(
    llm_provider_id: int,

@@ -267,10 +270,35 @@ def upsert_llm_provider(
        mc.name for mc in llm_provider_upsert_request.model_configurations
    }

    # Build a lookup of requested visibility by model name
    requested_visibility = {
        mc.name: mc.is_visible
        for mc in llm_provider_upsert_request.model_configurations
    }

    # Delete removed models
    removed_ids = [
        mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
    ]

    default_model = fetch_default_llm_model(db_session)

    # Prevent removing or hiding the default model
    if default_model:
        for name, mc in existing_by_name.items():
            if mc.id == default_model.id:
                if default_model.id in removed_ids:
                    raise ValueError(
                        f"Cannot remove the default model '{name}'. "
                        "Please change the default model before removing."
                    )
                if not requested_visibility.get(name, True):
                    raise ValueError(
                        f"Cannot hide the default model '{name}'. "
                        "Please change the default model before hiding."
                    )
                break

    if removed_ids:
        db_session.query(ModelConfiguration).filter(
            ModelConfiguration.id.in_(removed_ids)

@@ -532,9 +560,9 @@ def fetch_default_model(
) -> ModelConfiguration | None:
    model_config = db_session.scalar(
        select(ModelConfiguration)
        .options(selectinload(ModelConfiguration.llm_provider))
        .join(LLMModelFlow)
        .where(
            ModelConfiguration.is_visible == True,  # noqa: E712
            LLMModelFlow.llm_model_flow_type == flow_type,
            LLMModelFlow.is_default == True,  # noqa: E712
        )

@@ -810,6 +838,29 @@ def sync_auto_mode_models(
        )
        changes += 1

    # Update the default if this provider currently holds the global CHAT default.
    # We flush (but don't commit) so that _update_default_model can see the new
    # model rows, then commit everything atomically to avoid a window where the
    # old default is invisible but still pointed-to.
    db_session.flush()

    recommended_default = llm_recommendations.get_default_model(provider.provider)
    if recommended_default:
        current_default = fetch_default_llm_model(db_session)

        if (
            current_default
            and current_default.llm_provider_id == provider.id
            and current_default.name != recommended_default.name
        ):
            _update_default_model__no_commit(
                db_session=db_session,
                provider_id=provider.id,
                model=recommended_default.name,
                flow_type=LLMModelFlowType.CHAT,
            )
            changes += 1

    db_session.commit()
    return changes

@@ -941,7 +992,7 @@ def update_model_configuration__no_commit(
    db_session.flush()


def _update_default_model(
def _update_default_model__no_commit(
    db_session: Session,
    provider_id: int,
    model: str,

@@ -979,6 +1030,14 @@ def _update_default_model(
    new_default.is_default = True
    model_config.is_visible = True


def _update_default_model(
    db_session: Session,
    provider_id: int,
    model: str,
    flow_type: LLMModelFlowType,
) -> None:
    _update_default_model__no_commit(db_session, provider_id, model, flow_type)
    db_session.commit()
|
||||
@@ -25,6 +25,7 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
@@ -36,9 +37,11 @@ from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
@@ -117,10 +120,50 @@ class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
class _EncryptedBase(TypeDecorator):
|
||||
"""Base for encrypted column types that wrap values in SensitiveValue."""
|
||||
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def wrap_raw(self, value: Any) -> SensitiveValue:
|
||||
"""Encrypt a raw value and wrap it in SensitiveValue.
|
||||
|
||||
Called by the attribute set event so the Python-side type is always
|
||||
SensitiveValue, regardless of whether the value was loaded from the DB
|
||||
or assigned in application code.
|
||||
"""
|
||||
if self._is_json:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f"EncryptedJson column expected dict, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = json.dumps(value)
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"EncryptedString column expected str, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=encrypt_string_to_bytes(raw_str),
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=self._is_json,
|
||||
)
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
@@ -144,20 +187,9 @@ class EncryptedString(TypeDecorator):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
_is_json: bool = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
@@ -165,9 +197,7 @@ class EncryptedJson(TypeDecorator):
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
@@ -184,14 +214,40 @@ class EncryptedJson(TypeDecorator):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
_REGISTERED_ATTRS: set[str] = set()
|
||||
|
||||
|
||||
@event.listens_for(Mapper, "mapper_configured")
|
||||
def _register_sensitive_value_set_events(
|
||||
mapper: Mapper,
|
||||
class_: type,
|
||||
) -> None:
|
||||
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, _EncryptedBase):
|
||||
col_type = col.type
|
||||
attr = getattr(class_, prop.key)
|
||||
|
||||
# Guard against double-registration (e.g. if mapper is
|
||||
# re-configured in test setups)
|
||||
attr_key = f"{class_.__qualname__}.{prop.key}"
|
||||
if attr_key in _REGISTERED_ATTRS:
|
||||
continue
|
||||
_REGISTERED_ATTRS.add(attr_key)
|
||||
|
||||
@event.listens_for(attr, "set", retval=True)
|
||||
def _wrap_value(
|
||||
target: Any, # noqa: ARG001
|
||||
value: Any,
|
||||
oldvalue: Any, # noqa: ARG001
|
||||
initiator: Any, # noqa: ARG001
|
||||
_col_type: _EncryptedBase = col_type,
|
||||
) -> Any:
|
||||
if value is not None and not isinstance(value, SensitiveValue):
|
||||
return _col_type.wrap_raw(value)
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
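To illustrate what the `set`-event listener buys: assuming a hypothetical model `Credential` with an `EncryptedString` column named `api_key` (both names are made up, not taken from this diff), a plain-string assignment is wrapped transparently:

```python
# Illustrative only: `Credential` is a made-up ORM model with an
# EncryptedString column named `api_key`.
credential = Credential()
credential.api_key = "sk-raw-secret"                   # plain str goes in ...
assert isinstance(credential.api_key, SensitiveValue)  # ... SensitiveValue comes out

# Reading the raw value back requires an explicit, unmasked access.
raw = credential.api_key.get_value(apply_mask=False)
```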
@@ -2370,6 +2426,38 @@ class SyncRecord(Base):
    )


class HierarchyNodeByConnectorCredentialPair(Base):
    """Tracks which cc_pairs reference each hierarchy node.

    During pruning, stale entries are removed for the current cc_pair.
    Hierarchy nodes with zero remaining entries are then deleted.
    """

    __tablename__ = "hierarchy_node_by_connector_credential_pair"

    hierarchy_node_id: Mapped[int] = mapped_column(
        ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
    )
    connector_id: Mapped[int] = mapped_column(primary_key=True)
    credential_id: Mapped[int] = mapped_column(primary_key=True)

    __table_args__ = (
        ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
        Index(
            "ix_hierarchy_node_cc_pair_connector_credential",
            "connector_id",
            "credential_id",
        ),
    )


class DocumentByConnectorCredentialPair(Base):
    """Represents an indexing of a document by a specific connector / credential pair"""
@@ -205,7 +205,9 @@ def update_persona_access(

    NOTE: Callers are responsible for committing."""

    needs_sync = False
    if is_public is not None:
        needs_sync = True
        persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
        if persona:
            persona.is_public = is_public

@@ -213,6 +215,7 @@ def update_persona_access(
    # NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
    # and a non-empty list means "replace with these shares".
    if user_ids is not None:
        needs_sync = True
        db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        ).delete(synchronize_session="fetch")

@@ -233,6 +236,7 @@ def update_persona_access(
    # MIT doesn't support group-based sharing, so we allow clearing (no-op since
    # there shouldn't be any) but raise an error if trying to add actual groups.
    if group_ids is not None:
        needs_sync = True
        db_session.query(Persona__UserGroup).filter(
            Persona__UserGroup.persona_id == persona_id
        ).delete(synchronize_session="fetch")

@@ -240,6 +244,10 @@ def update_persona_access(
        if group_ids:
            raise NotImplementedError("Onyx MIT does not support group-based sharing")

    # When sharing changes, user file ACLs need to be updated in the vector DB
    if needs_sync:
        mark_persona_user_files_for_sync(persona_id, db_session)


def create_update_persona(
    persona_id: int | None,

@@ -851,6 +859,24 @@ def update_personas_display_priority(
    db_session.commit()


def mark_persona_user_files_for_sync(
    persona_id: int,
    db_session: Session,
) -> None:
    """When persona sharing changes, mark all of its user files for sync
    so that their ACLs get updated in the vector DB."""
    persona = (
        db_session.query(Persona)
        .options(selectinload(Persona.user_files))
        .filter(Persona.id == persona_id)
        .first()
    )
    if not persona:
        return
    file_ids = [uf.id for uf in persona.user_files]
    _mark_files_need_persona_sync(db_session, file_ids)


def _mark_files_need_persona_sync(
    db_session: Session,
    user_file_ids: list[UUID],
@@ -52,7 +52,7 @@ def create_user_files(
) -> CategorizedFilesResult:

    # Categorize the files
    categorized_files = categorize_uploaded_files(files)
    categorized_files = categorize_uploaded_files(files, db_session)
    # NOTE: At the moment, zip metadata is not used for user files.
    # Should revisit to decide whether this should be a feature.
    upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
161 backend/onyx/db/rotate_encryption_key.py Normal file
@@ -0,0 +1,161 @@
"""Rotate encryption key for all encrypted columns.

Dynamically discovers all columns using EncryptedString / EncryptedJson,
decrypts each value with the old key, and re-encrypts with the current
ENCRYPTION_KEY_SECRET.

The operation is idempotent: rows already encrypted with the current key
are skipped. Commits are made in batches so a crash mid-rotation can be
safely resumed by re-running.
"""

import json
from typing import Any

from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.db.models import Base
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version

logger = setup_logger()

_BATCH_SIZE = 500


def _can_decrypt_with_current_key(data: bytes) -> bool:
    """Check if data is already encrypted with the current key.

    Passes the key explicitly so the fallback-to-raw-decode path in
    _decrypt_bytes is NOT triggered — a clean success/failure signal.
    """
    try:
        decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
        return True
    except Exception:
        return False


def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
    """Walk all ORM models and find columns using EncryptedString/EncryptedJson.

    Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
    """
    results: list[tuple[type, str, list[str], bool]] = []

    for mapper in Base.registry.mappers:
        model_cls = mapper.class_
        pk_names = [col.key for col in mapper.primary_key]

        for prop in mapper.column_attrs:
            for col in prop.columns:
                if isinstance(col.type, EncryptedJson):
                    results.append((model_cls, prop.key, pk_names, True))
                elif isinstance(col.type, EncryptedString):
                    results.append((model_cls, prop.key, pk_names, False))

    return results


def rotate_encryption_key(
    db_session: Session,
    old_key: str | None,
    dry_run: bool = False,
) -> dict[str, int]:
    """Decrypt all encrypted columns with old_key and re-encrypt with the current key.

    Args:
        db_session: Active database session.
        old_key: The previous encryption key. Pass None or "" if values were
            not previously encrypted with a key.
        dry_run: If True, count rows that need rotation without modifying data.

    Returns:
        Dict of "table.column" -> number of rows re-encrypted (or would be).

    Commits every _BATCH_SIZE rows so that locks are held briefly and progress
    is preserved on crash. Already-rotated rows are detected and skipped,
    making the operation safe to re-run.
    """
    if not global_version.is_ee_version():
        raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")

    if not ENCRYPTION_KEY_SECRET:
        raise RuntimeError(
            "ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
            "Set the target encryption key in the environment before running."
        )

    encrypted_columns = _discover_encrypted_columns()
    totals: dict[str, int] = {}

    for model_cls, col_name, pk_names, is_json in encrypted_columns:
        table_name: str = model_cls.__tablename__  # type: ignore[attr-defined]
        col_attr = getattr(model_cls, col_name)
        pk_attrs = [getattr(model_cls, pk) for pk in pk_names]

        # Read raw bytes directly, bypassing the TypeDecorator
        raw_col = col_attr.property.columns[0]

        stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
        rows = db_session.execute(stmt).all()

        reencrypted = 0
        batch_pending = 0
        for row in rows:
            raw_bytes: bytes | None = row[-1]
            if raw_bytes is None:
                continue

            if _can_decrypt_with_current_key(raw_bytes):
                continue

            try:
                if not old_key:
                    decrypted_str = raw_bytes.decode("utf-8")
                else:
                    decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)

                # For EncryptedJson, parse back to dict so the TypeDecorator
                # can json.dumps() it cleanly (avoids double-encoding).
                value: Any = json.loads(decrypted_str) if is_json else decrypted_str
            except (ValueError, UnicodeDecodeError) as e:
                pk_vals = [row[i] for i in range(len(pk_names))]
                logger.warning(
                    f"Could not decrypt/parse {table_name}.{col_name} "
                    f"row {pk_vals} — skipping: {e}"
                )
                continue

            if not dry_run:
                pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
                update_stmt = (
                    update(model_cls).where(*pk_filters).values({col_name: value})
                )
                db_session.execute(update_stmt)
                batch_pending += 1

                if batch_pending >= _BATCH_SIZE:
                    db_session.commit()
                    batch_pending = 0
            reencrypted += 1

        # Flush remaining rows in this column
        if batch_pending > 0:
            db_session.commit()

        if reencrypted > 0:
            totals[f"{table_name}.{col_name}"] = reencrypted
            logger.info(
                f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
                f"{reencrypted} value(s) in {table_name}.{col_name}"
            )

    return totals
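A sketch of how this module might be invoked; the `get_session` factory name is a placeholder, and the old key would come from wherever the previous secret was stored:

```python
# Illustrative driver (assumes `get_session` yields a SQLAlchemy Session;
# the name is a placeholder, not an Onyx API).
with get_session() as db_session:
    # First pass: see what would change without touching data.
    plan = rotate_encryption_key(db_session, old_key="old-secret", dry_run=True)
    print(plan)  # e.g. {"credential.credential_json": 12, ...}

    # Second pass: actually rotate. Safe to re-run if interrupted, since
    # rows already readable with the current key are skipped.
    rotate_encryption_key(db_session, old_key="old-secret")
```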
@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
    latest_settings = result.scalars().first()

    if not latest_settings:
        raise RuntimeError("No search settings specified, DB is not in a valid state")
        raise RuntimeError("No search settings specified; DB is not in a valid state.")
    return latest_settings
@@ -13,12 +13,15 @@ from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import MCPServerStatus
from onyx.db.models import MCPServer
from onyx.db.models import OAuthConfig
from onyx.db.models import Tool
from onyx.db.models import ToolCall
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_json_like
from onyx.utils.postgres_sanitization import sanitize_string

if TYPE_CHECKING:
    pass

@@ -159,10 +162,26 @@ def update_tool(
    ]
    if passthrough_auth is not None:
        tool.passthrough_auth = passthrough_auth
    old_oauth_config_id = tool.oauth_config_id
    if not isinstance(oauth_config_id, UnsetType):
        tool.oauth_config_id = oauth_config_id
    db_session.commit()
    db_session.flush()

    # Clean up orphaned OAuthConfig if the oauth_config_id was changed
    if (
        old_oauth_config_id is not None
        and not isinstance(oauth_config_id, UnsetType)
        and old_oauth_config_id != oauth_config_id
    ):
        other_tools = db_session.scalars(
            select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
        ).all()
        if not other_tools:
            oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
            if oauth_config:
                db_session.delete(oauth_config)

    db_session.commit()
    return tool


@@ -171,8 +190,21 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
    if tool is None:
        raise ValueError(f"Tool with ID {tool_id} does not exist")

    oauth_config_id = tool.oauth_config_id

    db_session.delete(tool)
    db_session.flush()  # Don't commit yet, let caller decide when to commit
    db_session.flush()

    # Clean up orphaned OAuthConfig if no other tools reference it
    if oauth_config_id is not None:
        other_tools = db_session.scalars(
            select(Tool).where(Tool.oauth_config_id == oauth_config_id)
        ).all()
        if not other_tools:
            oauth_config = db_session.get(OAuthConfig, oauth_config_id)
            if oauth_config:
                db_session.delete(oauth_config)
                db_session.flush()


def get_builtin_tool(

@@ -256,11 +288,13 @@ def create_tool_call_no_commit(
        tab_index=tab_index,
        tool_id=tool_id,
        tool_call_id=tool_call_id,
        reasoning_tokens=reasoning_tokens,
        tool_call_arguments=tool_call_arguments,
        tool_call_response=tool_call_response,
        reasoning_tokens=(
            sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
        ),
        tool_call_arguments=sanitize_json_like(tool_call_arguments),
        tool_call_response=sanitize_json_like(tool_call_response),
        tool_call_tokens=tool_call_tokens,
        generated_images=generated_images,
        generated_images=sanitize_json_like(generated_images),
    )

    db_session.add(tool_call)
@@ -3,9 +3,11 @@ from uuid import UUID

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.db.models import Persona
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile

@@ -118,3 +120,31 @@ def get_file_ids_by_user_file_ids(
) -> list[str]:
    user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
    return [user_file.file_id for user_file in user_files]


def fetch_user_files_with_access_relationships(
    user_file_ids: list[str],
    db_session: Session,
    eager_load_groups: bool = False,
) -> list[UserFile]:
    """Fetch user files with the owner and assistant relationships
    eagerly loaded (needed for computing access control).

    When eager_load_groups is True, Persona.groups is also loaded so that
    callers can extract user-group names without a second DB round-trip."""
    persona_sub_options = [
        selectinload(Persona.users),
        selectinload(Persona.user),
    ]
    if eager_load_groups:
        persona_sub_options.append(selectinload(Persona.groups))

    return (
        db_session.query(UserFile)
        .options(
            joinedload(UserFile.user),
            selectinload(UserFile.assistants).options(*persona_sub_options),
        )
        .filter(UserFile.id.in_(user_file_ids))
        .all()
    )
103 backend/onyx/document_index/FILTER_SEMANTICS.md Normal file
@@ -0,0 +1,103 @@
# Vector DB Filter Semantics

How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.

## Filter categories

| Category | Fields | Join logic |
|---|---|---|
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |

## How filters combine

All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.

```
NOT hidden
AND tenant = T                            -- if multi-tenant
AND (acl contains A1 OR acl contains A2)
AND (source_type = S1 OR ...)             -- if set
AND (tag = T1 OR ...)                     -- if set
AND <knowledge scope>                     -- see below
AND time >= cutoff                        -- if set
```

## Knowledge scope rules

The knowledge scope filter controls **what knowledge an assistant can access**.

### No explicit knowledge attached

When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:

- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id` and `persona_id` are ignored — they never restrict on their own.

### One explicit knowledge type

```
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")

-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
```

### Multiple explicit knowledge types (OR'd)

```
-- Document sets + user files
AND (
  document_sets contains "Engineering"
  OR document_id = "uuid-1"
)
```

### Explicit knowledge + overflowing user files

When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:

```
-- Document sets + persona user files overflowed
AND (
  document_sets contains "Engineering"
  OR personas contains 42
)

-- User files + project files overflowed
AND (
  document_id = "uuid-1"
  OR user_project contains 7
)
```

### Only project_id or persona_id (no explicit knowledge)

No knowledge scope filter. The assistant searches everything.

```
-- Just ACL, no restriction
NOT hidden
AND (acl contains ...)
```

## Field reference

| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
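A minimal sketch of the knowledge-scope rules in code. The dataclass fields mirror the tables above, but this illustrates the documented semantics only; it is not the actual query builder, and the clause strings are schematic:

```python
# Illustrative only: builds the knowledge-scope clause per the rules above.
from dataclasses import dataclass, field


@dataclass
class IndexFiltersSketch:
    document_set: list[str] = field(default_factory=list)
    user_file_ids: list[str] = field(default_factory=list)
    attached_document_ids: list[str] = field(default_factory=list)
    hierarchy_node_ids: list[int] = field(default_factory=list)
    project_id: int | None = None
    persona_id: int | None = None


def knowledge_scope_clause(f: IndexFiltersSketch) -> str | None:
    parts: list[str] = []
    parts += [f'document_sets contains "{ds}"' for ds in f.document_set]
    parts += [
        f'document_id = "{uid}"' for uid in f.user_file_ids + f.attached_document_ids
    ]
    parts += [f"ancestor_hierarchy_node_ids contains {n}" for n in f.hierarchy_node_ids]

    if not parts:
        # No explicit knowledge -> no scope filter; additive scopes are ignored.
        return None

    # Additive scopes only widen an existing restriction (overflowed user files).
    if f.project_id is not None:
        parts.append(f"user_project contains {f.project_id}")
    if f.persona_id is not None:
        parts.append(f"personas contains {f.persona_id}")
    return "(" + " OR ".join(parts) + ")"
```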
@@ -32,9 +32,6 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
    Determines whether to enable multipass and large chunks by examining
    the current search settings and the embedder configuration.
    """
    if not search_settings:
        return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)

    multipass = should_use_multipass(search_settings)
    enable_large_chunks = SearchSettings.can_use_large_chunks(
        multipass, search_settings.model_name, search_settings.provider_type
@@ -26,11 +26,10 @@ def get_default_document_index(
|
||||
To be used for retrieval only. Indexing should be done through both indices
|
||||
until Vespa is deprecated.
|
||||
|
||||
Pre-existing docstring for this function, although secondary indices are not
|
||||
currently supported:
|
||||
Primary index is the index that is used for querying/updating etc. Secondary
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated, updates are applied to both indices.
|
||||
need to be updated. Updates are applied to both indices.
|
||||
WARNING: In that case, get_all_document_indices should be used.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
@@ -51,11 +50,26 @@ def get_default_document_index(
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -86,8 +100,7 @@ def get_all_document_indices(
|
||||
Used for indexing only. Until Vespa is deprecated we will index into both
|
||||
document indices. Retrieval is done through only one index however.
|
||||
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
Large chunks are not currently supported so we hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
@@ -123,13 +136,36 @@ def get_all_document_indices(
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
@@ -61,6 +61,25 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class IndexInfo(BaseModel):
|
||||
"""
|
||||
Represents information about an OpenSearch index.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
name: str
|
||||
health: str
|
||||
status: str
|
||||
num_primary_shards: str
|
||||
num_replica_shards: str
|
||||
docs_count: str
|
||||
docs_deleted: str
|
||||
created_at: str
|
||||
total_size: str
|
||||
primary_shards_size: str
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
@@ -159,8 +178,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -173,8 +192,8 @@ class OpenSearchClient(AbstractContextManager):
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
response = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
@@ -198,6 +217,34 @@ class OpenSearchClient(AbstractContextManager):
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def list_indices_with_info(self) -> list[IndexInfo]:
|
||||
"""
|
||||
Lists the indices in the OpenSearch cluster with information about each
|
||||
index.
|
||||
|
||||
Returns:
|
||||
A list of IndexInfo objects for each index.
|
||||
"""
|
||||
response = self._client.cat.indices(format="json")
|
||||
indices: list[IndexInfo] = []
|
||||
for raw_index_info in response:
|
||||
indices.append(
|
||||
IndexInfo(
|
||||
name=raw_index_info.get("index", ""),
|
||||
health=raw_index_info.get("health", ""),
|
||||
status=raw_index_info.get("status", ""),
|
||||
num_primary_shards=raw_index_info.get("pri", ""),
|
||||
num_replica_shards=raw_index_info.get("rep", ""),
|
||||
docs_count=raw_index_info.get("docs.count", ""),
|
||||
docs_deleted=raw_index_info.get("docs.deleted", ""),
|
||||
created_at=raw_index_info.get("creation.date.string", ""),
|
||||
total_size=raw_index_info.get("store.size", ""),
|
||||
primary_shards_size=raw_index_info.get("pri.store.size", ""),
|
||||
)
|
||||
)
|
||||
return indices
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
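A hedged usage sketch for the new listing API (how `OpenSearchClient` is constructed is not shown in this diff, so that part is illustrative):

```
# Illustrative setup; the real constructor arguments are defined elsewhere.
with OpenSearchClient() as client:
    for info in client.list_indices_with_info():
        # All IndexInfo fields are strings, mirroring the `_cat/indices` JSON output.
        print(f"{info.name}: health={info.health} docs={info.docs_count} size={info.total_size}")
```
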
@@ -271,6 +271,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        embedding_dim: int,
        embedding_precision: EmbeddingPrecision,
        secondary_index_name: str | None,
        secondary_embedding_dim: int | None,
        secondary_embedding_precision: EmbeddingPrecision | None,
        # NOTE: We do not support large chunks right now.
        large_chunks_enabled: bool,  # noqa: ARG002
        secondary_large_chunks_enabled: bool | None,  # noqa: ARG002
        multitenant: bool = False,
@@ -286,12 +289,25 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
                f"Expected {MULTI_TENANT}, got {multitenant}."
            )
        tenant_id = get_current_tenant_id()
        tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
        self._real_index = OpenSearchDocumentIndex(
            tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
            tenant_state=tenant_state,
            index_name=index_name,
            embedding_dim=embedding_dim,
            embedding_precision=embedding_precision,
        )
        self._secondary_real_index: OpenSearchDocumentIndex | None = None
        if self.secondary_index_name:
            if secondary_embedding_dim is None or secondary_embedding_precision is None:
                raise ValueError(
                    "Bug: Secondary index embedding dimension and precision are not set."
                )
            self._secondary_real_index = OpenSearchDocumentIndex(
                tenant_state=tenant_state,
                index_name=self.secondary_index_name,
                embedding_dim=secondary_embedding_dim,
                embedding_precision=secondary_embedding_precision,
            )

    @staticmethod
    def register_multitenant_indices(
@@ -307,19 +323,38 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        self,
        primary_embedding_dim: int,
        primary_embedding_precision: EmbeddingPrecision,
        secondary_index_embedding_dim: int | None,  # noqa: ARG002
        secondary_index_embedding_precision: EmbeddingPrecision | None,  # noqa: ARG002
        secondary_index_embedding_dim: int | None,
        secondary_index_embedding_precision: EmbeddingPrecision | None,
    ) -> None:
        # Only handle primary index for now, ignore secondary.
        return self._real_index.verify_and_create_index_if_necessary(
        self._real_index.verify_and_create_index_if_necessary(
            primary_embedding_dim, primary_embedding_precision
        )
        if self.secondary_index_name:
            if (
                secondary_index_embedding_dim is None
                or secondary_index_embedding_precision is None
            ):
                raise ValueError(
                    "Bug: Secondary index embedding dimension and precision are not set."
                )
            assert (
                self._secondary_real_index is not None
            ), "Bug: Secondary index is not initialized."
            self._secondary_real_index.verify_and_create_index_if_necessary(
                secondary_index_embedding_dim, secondary_index_embedding_precision
            )

    def index(
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        """
        NOTE: Do NOT consider the secondary index here. A separate indexing
        pipeline will be responsible for indexing to the secondary index. This
        design is not ideal and we should reconsider this when revamping index
        swapping.
        """
        # Convert IndexBatchParams to IndexingMetadata.
        chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
        for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -351,7 +386,20 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        tenant_id: str,  # noqa: ARG002
        chunk_count: int | None,
    ) -> int:
        return self._real_index.delete(doc_id, chunk_count)
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for deleting chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
        if self.secondary_index_name:
            assert (
                self._secondary_real_index is not None
            ), "Bug: Secondary index is not initialized."
            total_chunks_deleted += self._secondary_real_index.delete(
                doc_id, chunk_count
            )
        return total_chunks_deleted

    def update_single(
        self,
@@ -362,6 +410,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        fields: VespaDocumentFields | None,
        user_fields: VespaDocumentUserFields | None,
    ) -> None:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for updating chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        if fields is None and user_fields is None:
            logger.warning(
                f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -392,6 +445,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):

        try:
            self._real_index.update([update_request])
            if self.secondary_index_name:
                assert (
                    self._secondary_real_index is not None
                ), "Bug: Secondary index is not initialized."
                self._secondary_real_index.update([update_request])
        except NotFoundError:
            logger.exception(
                f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
@@ -681,7 +739,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
            The number of chunks successfully deleted.
        """
        logger.debug(
            f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
            f"[OpenSearchDocumentIndex] Deleting document {document_id} from index "
            f"{self._index_name}."
        )
        query_body = DocumentQuery.delete_from_document_id_query(
            document_id=document_id,
@@ -717,7 +776,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
        specified documents.
        """
        logger.debug(
            f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
            f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index "
            f"{self._index_name}."
        )
        for update_request in update_requests:
            properties_to_update: dict[str, Any] = dict()
@@ -773,9 +833,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
            # here.
            # TODO(andrei): Fix the aforementioned race condition.
            raise ChunkCountNotFoundError(
                f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
                "application used to permit this but is not a supported state for a document when using OpenSearch. "
                "The document was likely just added to the indexing pipeline and the chunk count will be updated shortly."
                f"Tried to update document {doc_id} but its chunk count is not known. "
                "Older versions of the application used to permit this but is not a "
                "supported state for a document when using OpenSearch. The document was "
                "likely just added to the indexing pipeline and the chunk count will be "
                "updated shortly."
            )
        if doc_chunk_count == 0:
            raise ValueError(
@@ -807,7 +869,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
        chunk IDs vs querying for matching document chunks.
        """
        logger.debug(
            f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
            f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index "
            f"{self._index_name}."
        )
        results: list[InferenceChunk] = []
        for chunk_request in chunk_requests:
@@ -854,7 +917,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        logger.debug(
            f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
            f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index "
            f"{self._index_name}."
        )
        # TODO(andrei): This could be better, the caller should just make this
        # decision when passing in the query param. See the above comment in the
@@ -874,8 +938,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
            index_filters=filters,
            include_hidden=False,
        )
        # NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
        # Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
        # NOTE: Using z-score normalization here because it's better for hybrid
        # search from a theoretical standpoint. Empirically on a small dataset
        # of up to 10K docs, it's not very different. Likely more impactful at
        # scale.
        # https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
        search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
            body=query_body,
@@ -902,7 +968,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
        dirty: bool | None = None,  # noqa: ARG002
    ) -> list[InferenceChunk]:
        logger.debug(
            f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}."
            f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index "
            f"{self._index_name}."
        )
        query_body = DocumentQuery.get_random_search_query(
            tenant_state=self._tenant_state,
@@ -932,7 +999,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
        complete.
        """
        logger.debug(
            f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}."
            f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index "
            f"{self._index_name}."
        )
        # Do not raise if the document already exists, just update. This is
        # because the document may already have been indexed during the

@@ -243,7 +243,8 @@ class DocumentChunk(BaseModel):
            return value
        if not isinstance(value, int):
            raise ValueError(
                f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
                f"Bug: Expected an int for the last_updated property from OpenSearch, got "
                f"{type(value)} instead."
            )
        return datetime.fromtimestamp(value, tz=timezone.utc)

@@ -284,19 +285,22 @@ class DocumentChunk(BaseModel):
        elif isinstance(value, TenantState):
            if MULTI_TENANT != value.multitenant:
                raise ValueError(
                    f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
                    f"({value.multitenant}) does not match the program's current global tenancy state."
                    f"Bug: An existing TenantState object was supplied to the DocumentChunk model "
                    f"but its multi-tenant mode ({value.multitenant}) does not match the program's "
                    "current global tenancy state."
                )
            return value
        elif not isinstance(value, str):
            raise ValueError(
                f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
                f"Bug: Expected a str for the tenant_id property from OpenSearch, got "
                f"{type(value)} instead."
            )
        else:
            if not MULTI_TENANT:
                raise ValueError(
                    "Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
                    "This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
                    "Bug: Got a non-null str for the tenant_id property from OpenSearch but "
                    "multi-tenant mode is not enabled. This is unexpected because in single-tenant "
                    "mode we don't expect to see a tenant_id."
                )
            return TenantState(tenant_id=value, multitenant=MULTI_TENANT)

@@ -352,8 +356,10 @@ class DocumentSchema:
            "properties": {
                TITLE_FIELD_NAME: {
                    "type": "text",
                    # Language analyzer (e.g. english) stems at index and search time for variant matching.
                    # Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
                    # Language analyzer (e.g. english) stems at index and search
                    # time for variant matching. Configure via
                    # OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing
                    # after a change.
                    "analyzer": OPENSEARCH_TEXT_ANALYZER,
                    "fields": {
                        # Subfield accessed as title.keyword. Not indexed for

@@ -698,41 +698,6 @@ class DocumentQuery:
        """
        return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}

    def _get_assistant_knowledge_filter(
        attached_doc_ids: list[str] | None,
        node_ids: list[int] | None,
        file_ids: list[UUID] | None,
        document_sets: list[str] | None,
    ) -> dict[str, Any]:
        """Combined filter for assistant knowledge.

        When an assistant has attached knowledge, search should be scoped to:
        - Documents explicitly attached (by document ID), OR
        - Documents under attached hierarchy nodes (by ancestor node IDs), OR
        - User-uploaded files attached to the assistant, OR
        - Documents in the assistant's document sets (if any)
        """
        knowledge_filter: dict[str, Any] = {
            "bool": {"should": [], "minimum_should_match": 1}
        }
        if attached_doc_ids:
            knowledge_filter["bool"]["should"].append(
                _get_attached_document_id_filter(attached_doc_ids)
            )
        if node_ids:
            knowledge_filter["bool"]["should"].append(
                _get_hierarchy_node_filter(node_ids)
            )
        if file_ids:
            knowledge_filter["bool"]["should"].append(
                _get_user_file_id_filter(file_ids)
            )
        if document_sets:
            knowledge_filter["bool"]["should"].append(
                _get_document_set_filter(document_sets)
            )
        return knowledge_filter

    filter_clauses: list[dict[str, Any]] = []

    if not include_hidden:
@@ -758,41 +723,53 @@ class DocumentQuery:
            # document's metadata list.
            filter_clauses.append(_get_tag_filter(tags))

        # Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
        has_assistant_knowledge = (
        # Knowledge scope: explicit knowledge attachments restrict what
        # an assistant can see. When none are set the assistant
        # searches everything.
        #
        # project_id / persona_id are additive: they make overflowing
        # user files findable but must NOT trigger the restriction on
        # their own (an agent with no explicit knowledge should search
        # everything).
        has_knowledge_scope = (
            attached_document_ids
            or hierarchy_node_ids
            or user_file_ids
            or document_sets
        )

        if has_assistant_knowledge:
            # If assistant has attached knowledge, scope search to that knowledge.
            # Document sets are included in the OR filter so directly attached
            # docs are always findable even if not in the document sets.
            filter_clauses.append(
                _get_assistant_knowledge_filter(
                    attached_document_ids,
                    hierarchy_node_ids,
                    user_file_ids,
                    document_sets,
        if has_knowledge_scope:
            knowledge_filter: dict[str, Any] = {
                "bool": {"should": [], "minimum_should_match": 1}
            }
            if attached_document_ids:
                knowledge_filter["bool"]["should"].append(
                    _get_attached_document_id_filter(attached_document_ids)
                )
            )
        elif user_file_ids:
            # Fallback for non-assistant user file searches (e.g., project searches)
            # If at least one user file ID is provided, the caller will only
            # retrieve documents where the document ID is in this input list of
            # file IDs.
            filter_clauses.append(_get_user_file_id_filter(user_file_ids))

        if project_id is not None:
            # If a project ID is provided, the caller will only retrieve
            # documents where the project ID provided here is present in the
            # document's user projects list.
            filter_clauses.append(_get_user_project_filter(project_id))

        if persona_id is not None:
            filter_clauses.append(_get_persona_filter(persona_id))
            if hierarchy_node_ids:
                knowledge_filter["bool"]["should"].append(
                    _get_hierarchy_node_filter(hierarchy_node_ids)
                )
            if user_file_ids:
                knowledge_filter["bool"]["should"].append(
                    _get_user_file_id_filter(user_file_ids)
                )
            if document_sets:
                knowledge_filter["bool"]["should"].append(
                    _get_document_set_filter(document_sets)
                )
            # Additive: widen scope to also cover overflowing user
            # files, but only when an explicit restriction is already
            # in effect.
            if project_id is not None:
                knowledge_filter["bool"]["should"].append(
                    _get_user_project_filter(project_id)
                )
            if persona_id is not None:
                knowledge_filter["bool"]["should"].append(
                    _get_persona_filter(persona_id)
                )
            filter_clauses.append(knowledge_filter)

        if time_cutoff is not None:
            # If a time cutoff is provided, the caller will only retrieve

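For the "document sets plus overflowing persona files" case from the knowledge scope rules above, the clause appended to `filter_clauses` has this shape (the inner queries are illustrative; the real ones come from the private `_get_*_filter` helpers, which this diff does not show):

```
knowledge_filter = {
    "bool": {
        "should": [
            # from _get_document_set_filter(["Engineering"])
            {"terms": {"document_sets": ["Engineering"]}},
            # from _get_persona_filter(42)
            {"term": {"personas": 42}},
        ],
        # OR semantics: a chunk qualifies if any branch matches.
        "minimum_should_match": 1,
    }
}
```
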
@@ -1,5 +1,6 @@
import json
import string
import time
from collections.abc import Callable
from collections.abc import Mapping
from datetime import datetime
@@ -18,6 +19,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.document_index.interfaces import VespaChunkRequest
@@ -338,12 +340,18 @@ def get_all_chunks_paginated(
        params["continuation"] = continuation_token

    response: httpx.Response | None = None
    start_time = time.monotonic()
    try:
        with get_vespa_http_client() as http_client:
        with get_vespa_http_client(
            timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
        ) as http_client:
            response = http_client.get(url, params=params)
            response.raise_for_status()
    except httpx.HTTPError as e:
        error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
        error_base = (
            f"Failed to get chunks from Vespa slice {slice_id} with continuation token "
            f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds."
        )
        logger.exception(
            f"Request URL: {e.request.url}\n"
            f"Request Headers: {e.request.headers}\n"

@@ -465,6 +465,12 @@ class VespaIndex(DocumentIndex):
        chunks: list[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        """
        NOTE: Do NOT consider the secondary index here. A separate indexing
        pipeline will be responsible for indexing to the secondary index. This
        design is not ideal and we should reconsider this when revamping index
        swapping.
        """
        if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
            index_batch_params.doc_id_to_new_chunk_cnt
        ):
@@ -659,6 +665,10 @@ class VespaIndex(DocumentIndex):
        """Note: if the document id does not exist, the update will be a no-op and the
        function will complete with no errors or exceptions.
        Handle other exceptions if you wish to implement retry behavior

        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for updating chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        if fields is None and user_fields is None:
            logger.warning(
@@ -679,13 +689,6 @@ class VespaIndex(DocumentIndex):
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )

        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )

        project_ids: set[int] | None = None
        if user_fields is not None and user_fields.user_projects is not None:
            project_ids = set(user_fields.user_projects)
@@ -705,7 +708,20 @@ class VespaIndex(DocumentIndex):
            persona_ids=persona_ids,
        )

        vespa_document_index.update([update_request])
        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            vespa_document_index.update([update_request])

    def delete_single(
        self,
@@ -714,6 +730,11 @@ class VespaIndex(DocumentIndex):
        tenant_id: str,
        chunk_count: int | None,
    ) -> int:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for deleting chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
@@ -726,13 +747,25 @@ class VespaIndex(DocumentIndex):
            raise ValueError(
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        total_chunks_deleted = 0
        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            total_chunks_deleted += vespa_document_index.delete(
                document_id=doc_id, chunk_count=chunk_count
            )

        return total_chunks_deleted

    def id_based_retrieval(
        self,

@@ -52,7 +52,9 @@ def replace_invalid_doc_id_characters(text: str) -> str:
    return text.replace("'", "_")


def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
def get_vespa_http_client(
    no_timeout: bool = False, http2: bool = True, timeout: int | None = None
) -> httpx.Client:
    """
    Configures and returns an HTTP client for communicating with Vespa,
    including authentication if needed.
@@ -64,7 +66,7 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
            else None
        ),
        verify=False if not MANAGED_VESPA else True,
        timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
        timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
        http2=http2,
    )

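Note the precedence in the new signature: `no_timeout=True` still wins over any explicit `timeout`, and the `timeout` argument only replaces the default `VESPA_REQUEST_TIMEOUT`. A quick sketch:

```
# Default: VESPA_REQUEST_TIMEOUT.
client = get_vespa_http_client()
# Migration path: a longer, caller-chosen timeout.
client = get_vespa_http_client(timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S)
# no_timeout takes precedence over any explicit timeout.
client = get_vespa_http_client(no_timeout=True, timeout=30)  # -> timeout=None
```
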
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()


def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
    filter_str = f'({TENANT_ID} contains "{tenant_id}")'
    if include_trailing_and:
        filter_str += " and "
    return filter_str
def build_tenant_id_filter(tenant_id: str) -> str:
    return f'({TENANT_ID} contains "{tenant_id}")'


def build_vespa_filters(
@@ -37,30 +34,22 @@ def build_vespa_filters(
    remove_trailing_and: bool = False,  # Set to True when using as a complete Vespa query
) -> str:
    def _build_or_filters(key: str, vals: list[str] | None) -> str:
        """For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
        """For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
        Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
        if not key or not vals:
            return ""
        eq_elems = [f'{key} contains "{val}"' for val in vals if val]
        if not eq_elems:
            return ""
        or_clause = " or ".join(eq_elems)
        return f"({or_clause}) and "
        return f"({' or '.join(eq_elems)})"

    def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
        """
        For an integer field filter.
        If vals is not None, we want *only* docs whose key matches one of vals.
        """
        # If `vals` is None => skip the filter entirely
        """For an integer field filter.
        Returns a bare clause or ""."""
        if vals is None or not vals:
            return ""

        # Otherwise build the OR filter
        eq_elems = [f"{key} = {val}" for val in vals]
        or_clause = " or ".join(eq_elems)
        result = f"({or_clause}) and "

        return result
        return f"({' or '.join(eq_elems)})"

    def _build_kg_filter(
        kg_entities: list[str] | None,
@@ -73,16 +62,12 @@ def build_vespa_filters(
        combined_filter_parts = []

        def _build_kge(entity: str) -> str:
            # TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
            # TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
            # TYPE::* -> ({prefix: true}"TYPE")
            GENERAL = "::*"
            if entity.endswith(GENERAL):
                return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
            else:
                return f'"{entity}"'

        # OR the entities (give new design)
        if kg_entities:
            filter_parts = []
            for kg_entity in kg_entities:
@@ -104,8 +89,7 @@ def build_vespa_filters(

        # TODO: remove kg terms entirely from prompts and codebase

        # AND the combined filter parts
        return f"({' and '.join(combined_filter_parts)}) and "
        return f"({' and '.join(combined_filter_parts)})"

    def _build_kg_source_filters(
        kg_sources: list[str] | None,
@@ -114,16 +98,14 @@ def build_vespa_filters(
            return ""

        source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]

        return f"({' or '.join(source_phrases)}) and "
        return f"({' or '.join(source_phrases)})"

    def _build_kg_chunk_id_zero_only_filter(
        kg_chunk_id_zero_only: bool,
    ) -> str:
        if not kg_chunk_id_zero_only:
            return ""

        return "(chunk_id = 0 ) and "
        return "(chunk_id = 0)"

    def _build_time_filter(
        cutoff: datetime | None,
@@ -135,8 +117,8 @@ def build_vespa_filters(
        cutoff_secs = int(cutoff.timestamp())

        if include_untimed:
            return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
            return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
            return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
        return f"({DOC_UPDATED_AT} >= {cutoff_secs})"

    def _build_user_project_filter(
        project_id: int | None,
@@ -147,8 +129,7 @@ def build_vespa_filters(
            pid = int(project_id)
        except Exception:
            return ""
        # Vespa YQL 'contains' expects a string literal; quote the integer
        return f'({USER_PROJECT} contains "{pid}") and '
        return f'({USER_PROJECT} contains "{pid}")'

    def _build_persona_filter(
        persona_id: int | None,
@@ -160,73 +141,94 @@ def build_vespa_filters(
        except Exception:
            logger.warning(f"Invalid persona ID: {persona_id}")
            return ""
        return f'({PERSONAS} contains "{pid}") and '
        return f'({PERSONAS} contains "{pid}")'

    # Start building the filter string
    filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
    def _append(parts: list[str], clause: str) -> None:
        if clause:
            parts.append(clause)

    # Collect all top-level filter clauses, then join with " and " at the end.
    filter_parts: list[str] = []

    if not include_hidden:
        filter_parts.append(f"!({HIDDEN}=true)")

    # TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
    # If running in multi-tenant mode
    if filters.tenant_id and MULTI_TENANT:
        filter_str += build_tenant_id_filter(
            filters.tenant_id, include_trailing_and=True
        )
        filter_parts.append(build_tenant_id_filter(filters.tenant_id))

    # ACL filters
    if filters.access_control_list is not None:
        filter_str += _build_or_filters(
            ACCESS_CONTROL_LIST, filters.access_control_list
        _append(
            filter_parts,
            _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
        )

    # Source type filters
    source_strs = (
        [s.value for s in filters.source_type] if filters.source_type else None
    )
    filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
    _append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))

    # Tag filters
    tag_attributes = None
    if filters.tags:
        # build e.g. "tag_key|tag_value"
        tag_attributes = [
            f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
        ]
    filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
    _append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))

    # Document sets
    filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
    # Knowledge scope: explicit knowledge attachments (document_sets,
    # user_file_ids) restrict what an assistant can see. When none are
    # set, the assistant can see everything.
    #
    # project_id / persona_id are additive: they make overflowing user
    # files findable in Vespa but must NOT trigger the restriction on
    # their own (an agent with no explicit knowledge should search
    # everything).
    knowledge_scope_parts: list[str] = []

    _append(
        knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
    )

    # Convert UUIDs to strings for user_file_ids
    user_file_ids_str = (
        [str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
    )
    filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
    _append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))

    # User project filter (array<int> attribute membership)
    filter_str += _build_user_project_filter(filters.project_id)
    # Only include project/persona scopes when an explicit knowledge
    # restriction is already in effect — they widen the scope to also
    # cover overflowing user files but never restrict on their own.
    if knowledge_scope_parts:
        _append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
        _append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))

    # Persona filter (array<int> attribute membership)
    filter_str += _build_persona_filter(filters.persona_id)
    if len(knowledge_scope_parts) > 1:
        filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
    elif len(knowledge_scope_parts) == 1:
        filter_parts.append(knowledge_scope_parts[0])

    # Time filter
    filter_str += _build_time_filter(filters.time_cutoff)
    _append(filter_parts, _build_time_filter(filters.time_cutoff))

    # # Knowledge Graph Filters
    # filter_str += _build_kg_filter(
    # _append(filter_parts, _build_kg_filter(
    #     kg_entities=filters.kg_entities,
    #     kg_relationships=filters.kg_relationships,
    #     kg_terms=filters.kg_terms,
    # )
    # ))

    # filter_str += _build_kg_source_filters(filters.kg_sources)
    # _append(filter_parts, _build_kg_source_filters(filters.kg_sources))

    # filter_str += _build_kg_chunk_id_zero_only_filter(
    # _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
    #     filters.kg_chunk_id_zero_only or False
    # )
    # ))

    # Trim trailing " and "
    if remove_trailing_and and filter_str.endswith(" and "):
        filter_str = filter_str[:-5]
    filter_str = " and ".join(filter_parts)

    if filter_str and not remove_trailing_and:
        filter_str += " and "

    return filter_str
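
With the clause-list refactor, a fully populated call yields a string like the following (values are illustrative; field names follow the constants in this module; shown wrapped for readability, the real output is one line):

```
!(hidden=true) and (tenant_id contains "t1")
and (access_control_list contains "user:alice")
and ((document_sets contains "Engineering") or (personas contains "42"))
and (doc_updated_at >= 1700000000)
```
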
backend/onyx/error_handling/__init__.py (new, empty file)
backend/onyx/error_handling/error_codes.py (new file):
@@ -0,0 +1,101 @@
"""
Standardized error codes for the Onyx backend.

Usage:
    from onyx.error_handling.error_codes import OnyxErrorCode
    from onyx.error_handling.exceptions import OnyxError

    raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
"""

from enum import Enum


class OnyxErrorCode(Enum):
    """
    Each member is a tuple of (error_code_string, http_status_code).

    The error_code_string is a stable, machine-readable identifier that
    API consumers can match on. The http_status_code is the default HTTP
    status to return.
    """

    # ------------------------------------------------------------------
    # Authentication (401)
    # ------------------------------------------------------------------
    UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
    INVALID_TOKEN = ("INVALID_TOKEN", 401)
    TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
    CSRF_FAILURE = ("CSRF_FAILURE", 403)

    # ------------------------------------------------------------------
    # Authorization (403)
    # ------------------------------------------------------------------
    UNAUTHORIZED = ("UNAUTHORIZED", 403)
    INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
    ADMIN_ONLY = ("ADMIN_ONLY", 403)
    EE_REQUIRED = ("EE_REQUIRED", 403)

    # ------------------------------------------------------------------
    # Validation / Bad Request (400)
    # ------------------------------------------------------------------
    VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
    INVALID_INPUT = ("INVALID_INPUT", 400)
    MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)

    # ------------------------------------------------------------------
    # Not Found (404)
    # ------------------------------------------------------------------
    NOT_FOUND = ("NOT_FOUND", 404)
    CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
    CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
    PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
    DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
    SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
    USER_NOT_FOUND = ("USER_NOT_FOUND", 404)

    # ------------------------------------------------------------------
    # Conflict (409)
    # ------------------------------------------------------------------
    CONFLICT = ("CONFLICT", 409)
    DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)

    # ------------------------------------------------------------------
    # Rate Limiting / Quotas (429 / 402)
    # ------------------------------------------------------------------
    RATE_LIMITED = ("RATE_LIMITED", 429)
    SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)

    # ------------------------------------------------------------------
    # Connector / Credential Errors (400-range)
    # ------------------------------------------------------------------
    CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
    CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
    CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)

    # ------------------------------------------------------------------
    # Server Errors (5xx)
    # ------------------------------------------------------------------
    INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
    NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
    SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
    BAD_GATEWAY = ("BAD_GATEWAY", 502)
    LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
    GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)

    def __init__(self, code: str, status_code: int) -> None:
        self.code = code
        self.status_code = status_code

    def detail(self, message: str | None = None) -> dict[str, str]:
        """Build a structured error detail dict.

        Returns a dict like:
            {"error_code": "UNAUTHENTICATED", "detail": "Token expired"}

        If no message is supplied, the error code itself is used as the detail.
        """
        return {
            "error_code": self.code,
            "detail": message or self.code,
        }
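
What `detail()` produces, per its docstring:

```
OnyxErrorCode.UNAUTHENTICATED.detail("Token expired")
# -> {"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
OnyxErrorCode.RATE_LIMITED.detail()
# -> {"error_code": "RATE_LIMITED", "detail": "RATE_LIMITED"}
```
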
backend/onyx/error_handling/exceptions.py (new file):
@@ -0,0 +1,83 @@
"""OnyxError — the single exception type for all Onyx business errors.

Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
converts it into a JSON response with the standard
``{"error_code": "...", "detail": "..."}`` shape.

Usage::

    from onyx.error_handling.error_codes import OnyxErrorCode
    from onyx.error_handling.exceptions import OnyxError

    raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")

For upstream errors with a dynamic HTTP status (e.g. billing service),
use ``status_code_override``::

    raise OnyxError(
        OnyxErrorCode.BAD_GATEWAY,
        detail,
        status_code_override=upstream_status,
    )
"""

from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse

from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.utils.logger import setup_logger

logger = setup_logger()


class OnyxError(Exception):
    """Structured error that maps to a specific ``OnyxErrorCode``.

    Attributes:
        error_code: The ``OnyxErrorCode`` enum member.
        detail: Human-readable detail (defaults to the error code string).
        status_code: HTTP status — either overridden or from the error code.
    """

    def __init__(
        self,
        error_code: OnyxErrorCode,
        detail: str | None = None,
        *,
        status_code_override: int | None = None,
    ) -> None:
        resolved_detail = detail or error_code.code
        super().__init__(resolved_detail)
        self.error_code = error_code
        self.detail = resolved_detail
        self._status_code_override = status_code_override

    @property
    def status_code(self) -> int:
        return self._status_code_override or self.error_code.status_code


def register_onyx_exception_handlers(app: FastAPI) -> None:
    """Register a global handler that converts ``OnyxError`` to JSON responses.

    Must be called *after* the app is created but *before* it starts serving.
    The handler logs at WARNING for 4xx and ERROR for 5xx.
    """

    @app.exception_handler(OnyxError)
    async def _handle_onyx_error(
        request: Request,  # noqa: ARG001
        exc: OnyxError,
    ) -> JSONResponse:
        status_code = exc.status_code
        if status_code >= 500:
            logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
        elif status_code >= 400:
            logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")

        return JSONResponse(
            status_code=status_code,
            content=exc.error_code.detail(exc.detail),
        )
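
Putting the two new modules together, a minimal wiring sketch (the route and lookup are invented for illustration; only the imports and the handler registration come from this diff):

```
from fastapi import FastAPI

from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.error_handling.exceptions import register_onyx_exception_handlers

app = FastAPI()
register_onyx_exception_handlers(app)  # must run before the app starts serving


@app.get("/sessions/{session_id}")
async def get_session(session_id: str) -> dict:
    session = None  # hypothetical lookup, stands in for real session fetching
    if session is None:
        # The global handler converts this into a 404 response with body
        # {"error_code": "SESSION_NOT_FOUND", "detail": "Session not found"}.
        raise OnyxError(OnyxErrorCode.SESSION_NOT_FOUND, "Session not found")
    return session
```
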
@@ -19,12 +19,16 @@ class OnyxMimeTypes:
        PLAIN_TEXT_MIME_TYPE,
        "text/markdown",
        "text/x-markdown",
        "text/x-log",
        "text/x-config",
        "text/tab-separated-values",
        "application/json",
        "application/xml",
        "text/xml",
        "application/x-yaml",
        "application/yaml",
        "text/yaml",
        "text/x-yaml",
    }
    DOCUMENT_MIME_TYPES = {
        PDF_MIME_TYPE,

@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -65,6 +64,7 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time

@@ -1,4 +1,5 @@
import abc
from typing import cast

from onyx.utils.special_types import JSON_ro

@@ -7,6 +8,19 @@ class KvKeyNotFoundError(Exception):
    pass


def unwrap_str(val: JSON_ro) -> str:
    """Unwrap a string stored as {"value": str} in the encrypted KV store.
    Also handles legacy plain-string values cached in Redis."""
    if isinstance(val, dict):
        try:
            return cast(str, val["value"])
        except KeyError:
            raise ValueError(
                f"Expected dict with 'value' key, got keys: {list(val.keys())}"
            )
    return cast(str, val)


class KeyValueStore:
    # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
    # It's read from the global thread level variable

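How `unwrap_str` behaves for the two storage formats it mentions:

```
unwrap_str({"value": "secret-token"})  # new encrypted-KV wrapping -> "secret-token"
unwrap_str("secret-token")             # legacy plain string in Redis -> "secret-token"
unwrap_str({"wrong": 1})               # raises ValueError("Expected dict with 'value' key, ...")
```
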
@@ -22,6 +22,7 @@ class LlmProviderNames(str, Enum):
    OPENROUTER = "openrouter"
    AZURE = "azure"
    OLLAMA_CHAT = "ollama_chat"
    LM_STUDIO = "lm_studio"
    MISTRAL = "mistral"
    LITELLM_PROXY = "litellm_proxy"

@@ -41,6 +42,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
    LlmProviderNames.OPENROUTER,
    LlmProviderNames.AZURE,
    LlmProviderNames.OLLAMA_CHAT,
    LlmProviderNames.LM_STUDIO,
]

@@ -56,6 +58,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
    LlmProviderNames.AZURE: "Azure",
    "ollama": "Ollama",
    LlmProviderNames.OLLAMA_CHAT: "Ollama",
    LlmProviderNames.LM_STUDIO: "LM Studio",
    "groq": "Groq",
    "anyscale": "Anyscale",
    "deepseek": "DeepSeek",
@@ -103,6 +106,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
    LlmProviderNames.BEDROCK_CONVERSE,
    LlmProviderNames.OPENROUTER,
    LlmProviderNames.OLLAMA_CHAT,
    LlmProviderNames.LM_STUDIO,
    LlmProviderNames.VERTEX_AI,
    LlmProviderNames.AZURE,
}

@@ -20,7 +20,9 @@ from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.well_known_providers.constants import (
    PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
)
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
@@ -32,14 +34,18 @@ logger = setup_logger()
def _build_provider_extra_headers(
    provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
    if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
        raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
        api_key = raw_api_key.strip() if raw_api_key else None
    if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
        raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
        api_key = raw.strip() if raw else None
        if not api_key:
            return {}
        if not api_key.lower().startswith("bearer "):
            api_key = f"Bearer {api_key}"
        return {"Authorization": api_key}
        return {
            "Authorization": (
                api_key
                if api_key.lower().startswith("bearer ")
                else f"Bearer {api_key}"
            )
        }

    # Passing these will put Onyx on the OpenRouter leaderboard
    elif provider == LlmProviderNames.OPENROUTER:

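The generalized handling normalizes any configured key to a Bearer Authorization header. A sketch of the behavior (the config key names are whatever `PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING` maps each provider to; the literal key names below are assumptions):

```
_build_provider_extra_headers("ollama_chat", {"ollama_api_key": "abc123"})
# -> {"Authorization": "Bearer abc123"}
_build_provider_extra_headers("lm_studio", {"lm_studio_api_key": "Bearer xyz"})
# -> {"Authorization": "Bearer xyz"}  (an existing "Bearer " prefix is kept)
```
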