mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 22:55:46 +00:00
Compare commits
19 Commits
action_too
...
worktree-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f92287880 | ||
|
|
491c3bf906 | ||
|
|
1d898e343f | ||
|
|
0b129e24ea | ||
|
|
267ac5ac60 | ||
|
|
bf80211eae | ||
|
|
135385e57b | ||
|
|
f06630bc1b | ||
|
|
4495df98cf | ||
|
|
0124937aa8 | ||
|
|
aec2d24706 | ||
|
|
16ebb55362 | ||
|
|
ab6c11319e | ||
|
|
05f5b96964 | ||
|
|
f525aa175b | ||
|
|
4ba6e5f735 | ||
|
|
992ad3b8d4 | ||
|
|
a6404f8b3e | ||
|
|
efc49c9f6b |
154
.cursor/skills/onyx-cli/SKILL.md
Normal file
154
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,154 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx . && sudo mv onyx /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
The CLI is configured when `~/.config/onyx-cli/config.json` exists and contains an `api_key`. Check with:
|
||||
|
||||
```bash
|
||||
test -s ~/.config/onyx-cli/config.json && echo "configured" || echo "not configured"
|
||||
```
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: http://localhost:3000
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `http://localhost:3000`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `DANSWER_API_KEY` | No | Legacy fallback for `ONYX_API_KEY` |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx` needs to be configured and ask them to either:
|
||||
- Run `onyx configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `MessageDeltaEvent` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `StopEvent` | Stream complete |
|
||||
| `ErrorEvent` | Error with `error` message field |
|
||||
| `SearchStartEvent` | Onyx started searching documents |
|
||||
| `CitationEvent` | Source citation with `citation_number` and `document_id` |
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
@@ -54,7 +54,6 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
|
||||
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,9 +426,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -500,9 +499,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -648,8 +646,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -730,8 +728,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -864,9 +862,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -937,9 +934,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1076,8 +1072,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1149,8 +1145,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1291,9 +1287,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1371,9 +1366,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
14
.github/workflows/nightly-llm-provider-chat.yml
vendored
14
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -15,10 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
@@ -29,6 +25,16 @@ jobs:
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
|
||||
@@ -114,10 +114,8 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
- 'release/**'
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,13 +21,7 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -58,14 +52,21 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -48,10 +48,28 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
@@ -63,7 +81,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -72,19 +89,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
@@ -93,8 +97,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -106,7 +110,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -115,19 +118,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
@@ -136,8 +126,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
@@ -148,7 +138,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -157,19 +146,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
@@ -178,8 +154,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
@@ -194,56 +170,56 @@ jobs:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
@@ -254,7 +230,6 @@ jobs:
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -263,43 +238,21 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: 4a1e4b1c89d2
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "4a1e4b1c89d2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -1,112 +0,0 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""make scim_user_mapping.external_id nullable
|
||||
|
||||
Revision ID: a3b8d9e2f1c4
|
||||
Revises: 2664261bfaab
|
||||
Create Date: 2026-03-02
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3b8d9e2f1c4"
|
||||
down_revision = "2664261bfaab"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete any rows where external_id is NULL before re-applying NOT NULL
|
||||
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -126,16 +126,12 @@ class ScimDAL(DAL):
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str | None,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a SCIM mapping for a user.
|
||||
|
||||
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
|
||||
allows this). The mapping still marks the user as SCIM-managed.
|
||||
"""
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
@@ -274,13 +270,8 @@ class ScimDAL(DAL):
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
|
||||
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
|
||||
# unless they were explicitly linked via SCIM provisioning.
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -330,37 +321,34 @@ class ScimDAL(DAL):
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Sync the SCIM mapping for a user.
|
||||
|
||||
If a mapping already exists, its fields are updated (including
|
||||
setting ``external_id`` to ``None`` when the IdP omits it).
|
||||
If no mapping exists and ``new_external_id`` is provided, a new
|
||||
mapping is created. A mapping is never deleted here — SCIM-managed
|
||||
users must retain their mapping to remain visible in ``GET /Users``.
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
elif new_external_id:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
|
||||
@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -168,7 +167,6 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,9 +15,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -26,7 +24,6 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
@@ -80,22 +77,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -423,63 +404,21 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -497,21 +436,18 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
|
||||
@@ -19,6 +19,7 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -27,21 +28,6 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -96,14 +82,23 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
if not token.is_active:
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@@ -153,31 +153,26 @@ class ScimProvider(ABC):
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName:
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
if not user.personal_name:
|
||||
# Derive a reasonable name from the email so that SCIM spec tests
|
||||
# see non-empty givenName / familyName for every user resource.
|
||||
local = user.email.split("@")[0] if user.email else ""
|
||||
return ScimName(givenName=local, familyName="", formatted=local)
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
@@ -117,38 +117,15 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
if not llm_upsert_requests:
|
||||
return
|
||||
|
||||
logger.notice("Seeding LLMs")
|
||||
for request in llm_upsert_requests:
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
)
|
||||
if not default_provider:
|
||||
return
|
||||
|
||||
visible_configs = [
|
||||
mc for mc in default_provider.model_configurations if mc.is_visible
|
||||
]
|
||||
default_config = (
|
||||
visible_configs[0]
|
||||
if visible_configs
|
||||
else default_provider.model_configurations[0]
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=default_provider.id,
|
||||
model_name=default_config.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
if llm_upsert_requests:
|
||||
logger.notice("Seeding LLMs")
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
@@ -160,6 +137,12 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -171,7 +154,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -109,12 +109,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
settings.ee_features_enabled = False
|
||||
elif metadata.used_seats > metadata.seats:
|
||||
# License is valid but seat limit exceeded
|
||||
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
|
||||
settings.seat_count = metadata.seats
|
||||
settings.used_seats = metadata.used_seats
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
|
||||
@@ -33,7 +33,6 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -303,17 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=request.name, db_session=db_session
|
||||
)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -331,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -366,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -397,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -435,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
@@ -725,19 +725,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -241,7 +241,8 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,31 +414,34 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -446,18 +449,15 @@ def process_user_file_impl(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -471,6 +471,7 @@ def process_user_file_impl(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -492,8 +493,9 @@ def process_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -502,43 +504,33 @@ def process_user_file_impl(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -589,38 +581,36 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -658,6 +648,7 @@ def delete_user_file_impl(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -665,34 +656,26 @@ def delete_user_file_impl(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -764,30 +747,32 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -798,10 +783,11 @@ def project_sync_user_file_impl(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -836,7 +822,7 @@ def project_sync_user_file_impl(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -849,22 +835,11 @@ def project_sync_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,33 +1,3 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -39,168 +9,3 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process user file {file_id}")
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to sync user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
51
backend/onyx/cache/factory.py
vendored
51
backend/onyx/cache/factory.py
vendored
@@ -1,51 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
CacheBackendType.POSTGRES: _build_postgres_backend,
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
104
backend/onyx/cache/interface.py
vendored
104
backend/onyx/cache/interface.py
vendored
@@ -1,104 +0,0 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
if not self.acquire():
|
||||
raise RuntimeError("Failed to acquire lock")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds.
|
||||
|
||||
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
|
||||
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
323
backend/onyx/cache/postgres_backend.py
vendored
323
backend/onyx/cache/postgres_backend.py
vendored
@@ -1,323 +0,0 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
|
||||
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
|
||||
# lists whose names share a prefix (e.g. _q:mylist2:...).
|
||||
_LIST_KEY_RANGE_TERMINATOR = ";"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_LOCK_POLL_INTERVAL = 0.1
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
|
||||
collision when concurrent rpush calls occur within the same nanosecond.
|
||||
"""
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
|
||||
"""Advisory-lock-based distributed lock.
|
||||
|
||||
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
|
||||
to the session's connection; releasing or closing the session frees it.
|
||||
|
||||
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
|
||||
``timeout`` seconds. They are released when ``release()`` is
|
||||
called or when the session is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
|
||||
self._lock_id = lock_id
|
||||
self._timeout = timeout
|
||||
self._tenant_id = tenant_id
|
||||
self._session_cm: AbstractContextManager[Session] | None = None
|
||||
self._session: Session | None = None
|
||||
self._acquired = False
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
|
||||
self._session = self._session_cm.__enter__()
|
||||
try:
|
||||
if not blocking:
|
||||
return self._try_lock()
|
||||
|
||||
effective_timeout = blocking_timeout or self._timeout
|
||||
deadline = (
|
||||
(time.monotonic() + effective_timeout) if effective_timeout else None
|
||||
)
|
||||
while True:
|
||||
if self._try_lock():
|
||||
return True
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
return False
|
||||
time.sleep(_LOCK_POLL_INTERVAL)
|
||||
finally:
|
||||
if not self._acquired:
|
||||
self._close_session()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._acquired or self._session is None:
|
||||
return
|
||||
try:
|
||||
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._close_session()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
def _close_session(self) -> None:
|
||||
if self._session_cm is not None:
|
||||
try:
|
||||
self._session_cm.__exit__(None, None, None)
|
||||
finally:
|
||||
self._session_cm = None
|
||||
self._session = None
|
||||
|
||||
def _try_lock(self) -> bool:
|
||||
assert self._session is not None
|
||||
result = self._session.execute(
|
||||
select(func.pg_try_advisory_lock(self._lock_id))
|
||||
).scalar()
|
||||
if result:
|
||||
self._acquired = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
|
||||
|
||||
Each operation opens and closes its own database session so the backend
|
||||
is safe to share across threads. Tenant isolation is handled by
|
||||
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.value).where(
|
||||
CacheStore.key == key,
|
||||
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
value = session.execute(stmt).scalar_one_or_none()
|
||||
if value is None:
|
||||
return None
|
||||
return bytes(value)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
value_bytes = _to_bytes(value)
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=ex)
|
||||
if ex is not None
|
||||
else None
|
||||
)
|
||||
stmt = (
|
||||
pg_insert(CacheStore)
|
||||
.values(key=key, value=value_bytes, expires_at=expires_at)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[CacheStore.key],
|
||||
set_={"value": value_bytes, "expires_at": expires_at},
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(delete(CacheStore).where(CacheStore.key == key))
|
||||
session.commit()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = (
|
||||
select(CacheStore.key)
|
||||
.where(
|
||||
CacheStore.key == key,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
return session.execute(stmt).first() is not None
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
stmt = (
|
||||
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
expires_at: datetime | None = result[0]
|
||||
if expires_at is None:
|
||||
return TTL_NO_EXPIRY
|
||||
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
|
||||
if remaining <= 0:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
return int(remaining)
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return PostgresCacheLock(
|
||||
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
if timeout <= 0:
|
||||
raise ValueError(
|
||||
"PostgresCacheBackend.blpop requires timeout > 0. "
|
||||
"timeout=0 would block the calling thread indefinitely "
|
||||
"with no way to interrupt short of process termination."
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
for key in keys:
|
||||
lower = f"{_LIST_KEY_PREFIX}{key}:"
|
||||
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
|
||||
stmt = (
|
||||
select(CacheStore)
|
||||
.where(
|
||||
CacheStore.key >= lower,
|
||||
CacheStore.key < upper,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.order_by(CacheStore.key)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
row = session.execute(stmt).scalars().first()
|
||||
if row is not None:
|
||||
value = bytes(row.value) if row.value else b""
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return (key.encode(), value)
|
||||
if time.monotonic() >= deadline:
|
||||
return None
|
||||
time.sleep(_BLPOP_POLL_INTERVAL)
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
|
||||
"""Delete rows whose ``expires_at`` is in the past.
|
||||
|
||||
Called by the periodic poller every 5 minutes.
|
||||
"""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
delete(CacheStore).where(
|
||||
CacheStore.expires_at.is_not(None),
|
||||
CacheStore.expires_at < func.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
92
backend/onyx/cache/redis_backend.py
vendored
@@ -1,92 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
|
||||
|
||||
This is a thin pass-through — every method maps 1-to-1 to the underlying
|
||||
Redis command. ``TenantRedis`` key-prefixing is handled by the client
|
||||
itself (provided by ``get_redis_client``).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
val = self._r.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
return str(val).encode()
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self._r.set(key, value, ex=ex)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._r.delete(key)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return bool(self._r.exists(key))
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._r.expire(key, seconds)
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return cast(int, self._r.ttl(key))
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return RedisCacheLock(self._r.lock(name, timeout=timeout))
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self._r.rpush(key, value)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1])
|
||||
@@ -1,52 +1,57 @@
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""Generate the cache key for a chat session processing fence.
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, cache: CacheBackend, value: bool
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
) -> None:
|
||||
"""Set or clear the fence for a chat session processing a message.
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, a message is being processed.
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: The Redis client to use
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
cache.delete(fence_key)
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session is processing a message.
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: The Redis client to use
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
return cache.exists(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -531,13 +530,11 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
"read sections of any file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -561,16 +558,12 @@ def _create_context_files_message(
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
@@ -11,10 +11,9 @@ from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
@@ -80,6 +79,7 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
cache: CacheBackend | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
cache = get_cache_backend()
|
||||
redis_client = get_redis_client()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
cache,
|
||||
redis_client,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if cache is not None and chat_session is not None:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,58 +1,65 @@
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 10 * 60 # 10 minutes
|
||||
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""Generate the cache key for a chat session stop signal fence.
|
||||
"""
|
||||
Generate the Redis key for a chat session stop signal fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
|
||||
"""Set or clear the stop signal fence for a chat session.
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
if not value:
|
||||
cache.delete(fence_key)
|
||||
redis_client.delete(fence_key)
|
||||
return
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session should continue (not stopped).
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
return not cache.exists(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
|
||||
"""Clear the stop signal for a chat session.
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
"""
|
||||
cache.delete(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
@@ -819,9 +812,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
|
||||
@@ -98,7 +98,6 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -113,9 +112,6 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -186,7 +186,6 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -202,6 +202,7 @@ def create_default_image_gen_config_from_api_key(
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -109,38 +109,45 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
"""
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -206,29 +213,11 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider: LLMProviderModel | None = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_llm_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} not found"
|
||||
)
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
|
||||
if existing_llm_provider.name != llm_provider_upsert_request.name:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
|
||||
)
|
||||
else:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with name '{llm_provider_upsert_request.name}'"
|
||||
" already exists"
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -249,7 +238,11 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -313,6 +306,15 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -486,22 +488,6 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -532,7 +518,6 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
@@ -619,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
model_name,
|
||||
provider.default_model_name, # type: ignore[arg-type]
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -811,6 +805,12 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -3264,6 +3265,19 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3290,8 +3304,11 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -4926,9 +4943,7 @@ class ScimUserMapping(Base):
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
@@ -4985,25 +5000,3 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
|
||||
"""Key-value cache table used by ``PostgresCacheBackend``.
|
||||
|
||||
Replaces Redis for simple KV caching, locks, and list operations
|
||||
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
|
||||
|
||||
Intentionally separate from ``KVStore``:
|
||||
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
|
||||
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
|
||||
- Holds ephemeral data (tokens, stop signals, lock state) not
|
||||
persistent application config, so cleanup can be aggressive.
|
||||
"""
|
||||
|
||||
__tablename__ = "cache_store"
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -18,8 +18,11 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -251,15 +254,13 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -280,6 +281,7 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -293,7 +295,10 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
featured=create_persona_request.featured,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -869,6 +874,10 @@ def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -889,11 +898,13 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
featured: bool | None = None,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -1004,6 +1015,12 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1017,8 +1034,10 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1071,6 +1090,12 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1086,7 +1111,9 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
featured=(featured if featured is not None else False),
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1131,9 +1158,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_featured(
|
||||
def update_persona_is_default(
|
||||
persona_id: int,
|
||||
featured: bool,
|
||||
is_default: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1141,7 +1168,7 @@ def update_persona_featured(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.featured = featured
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -52,7 +51,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
@@ -106,8 +105,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -128,27 +127,16 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,6 +5,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -43,6 +45,8 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -69,13 +73,17 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -562,7 +563,12 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -526,7 +525,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -547,41 +546,3 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
|
||||
@@ -4,33 +4,39 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key prefix for OAuth state
|
||||
OAUTH_STATE_PREFIX = "federated_oauth"
|
||||
OAUTH_STATE_TTL = 300 # 5 minutes
|
||||
# Default TTL for OAuth state (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
|
||||
|
||||
class OAuthSession:
|
||||
"""Represents an OAuth session stored in the cache backend."""
|
||||
"""Represents an OAuth session stored in Redis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.federated_connector_id = federated_connector_id
|
||||
self.user_id = user_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.additional_data = additional_data or {}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for Redis storage."""
|
||||
return {
|
||||
"federated_connector_id": self.federated_connector_id,
|
||||
"user_id": self.user_id,
|
||||
@@ -39,7 +45,8 @@ class OAuthSession:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
|
||||
"""Create from dictionary retrieved from Redis."""
|
||||
return cls(
|
||||
federated_connector_id=data["federated_connector_id"],
|
||||
user_id=data["user_id"],
|
||||
@@ -51,27 +58,31 @@ class OAuthSession:
|
||||
def generate_oauth_state(
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
ttl: int = OAUTH_STATE_TTL,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a secure state parameter and store session data in the cache backend.
|
||||
Generate a secure state parameter and store session data in Redis.
|
||||
|
||||
Args:
|
||||
federated_connector_id: ID of the federated connector
|
||||
user_id: ID of the user initiating OAuth
|
||||
redirect_uri: Optional redirect URI after OAuth completion
|
||||
additional_data: Any additional data to store with the session
|
||||
ttl: Time-to-live in seconds for the cache key
|
||||
ttl: Time-to-live in seconds for the Redis key
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter
|
||||
"""
|
||||
# Generate a random UUID for the state
|
||||
state_uuid = uuid.uuid4()
|
||||
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Convert UUID to base64 for URL-safe state parameter
|
||||
state_bytes = state_uuid.bytes
|
||||
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Create session object
|
||||
session = OAuthSession(
|
||||
federated_connector_id=federated_connector_id,
|
||||
user_id=user_id,
|
||||
@@ -79,9 +90,15 @@ def generate_oauth_state(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
|
||||
# Store in Redis with TTL
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
json.dumps(session.to_dict()),
|
||||
ex=ttl,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
|
||||
@@ -108,15 +125,18 @@ def verify_oauth_state(state: str) -> OAuthSession:
|
||||
state_bytes = base64.urlsafe_b64decode(padded_state)
|
||||
state_uuid = uuid.UUID(bytes=state_bytes)
|
||||
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
# Look up in Redis
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
session_data = cache.get(cache_key)
|
||||
session_data = cast(bytes, redis_client.get(redis_key))
|
||||
if not session_data:
|
||||
raise ValueError(f"OAuth state not found: {state}")
|
||||
raise ValueError(f"OAuth state not found in Redis: {state}")
|
||||
|
||||
cache.delete(cache_key)
|
||||
# Delete the key after retrieval (one-time use)
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# Parse and return session
|
||||
session_dict = json.loads(session_data)
|
||||
return OAuthSession.from_dict(session_dict)
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -18,27 +20,22 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, cache: CacheBackend | None = None) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _get_cache(self) -> CacheBackend:
|
||||
if self._cache is None:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
self._cache = get_cache_backend()
|
||||
return self._cache
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client()
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
try:
|
||||
self._get_cache().set(
|
||||
self.redis_client.set(
|
||||
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback gracefully to Postgres if Cache backend fails
|
||||
logger.error(
|
||||
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
|
||||
)
|
||||
# Fallback gracefully to Postgres if Redis fails
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
@@ -56,12 +53,16 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
|
||||
if not refresh_cache:
|
||||
try:
|
||||
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
|
||||
if cached is not None:
|
||||
return json.loads(cached.decode("utf-8"))
|
||||
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
|
||||
if redis_value:
|
||||
if not isinstance(redis_value, bytes):
|
||||
raise ValueError(
|
||||
f"Redis value for key '{key}' is not a bytes object"
|
||||
)
|
||||
return json.loads(redis_value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get value from cache for key '{key}': {str(e)}"
|
||||
f"Failed to get value from Redis for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -78,21 +79,21 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self._get_cache().set(
|
||||
self.redis_client.set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
|
||||
return cast(JSON_ro, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
self._get_cache().delete(REDIS_KEY_PREFIX + key)
|
||||
self.redis_client.delete(REDIS_KEY_PREFIX + key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
|
||||
@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -13,38 +13,44 @@ from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
|
||||
# Redis key for caching the last updated timestamp (per-tenant)
|
||||
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
|
||||
|
||||
def _get_cached_last_updated_at() -> datetime | None:
|
||||
"""Get the cached last_updated_at timestamp from Redis."""
|
||||
try:
|
||||
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
if value is not None:
|
||||
redis_client = get_redis_client()
|
||||
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
if value and isinstance(value, bytes):
|
||||
# Value is bytes, decode to string then parse as ISO format
|
||||
return datetime.fromisoformat(value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cached last_updated_at: {e}")
|
||||
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_last_updated_at(updated_at: datetime) -> None:
|
||||
"""Set the cached last_updated_at timestamp in Redis."""
|
||||
try:
|
||||
get_cache_backend().set(
|
||||
_CACHE_KEY_LAST_UPDATED_AT,
|
||||
redis_client = get_redis_client()
|
||||
# Store as ISO format string, with 24 hour expiration
|
||||
redis_client.set(
|
||||
_REDIS_KEY_LAST_UPDATED_AT,
|
||||
updated_at.isoformat(),
|
||||
ex=_CACHE_TTL_SECONDS,
|
||||
ex=60 * 60 * 24, # 24 hours
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set cached last_updated_at: {e}")
|
||||
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
|
||||
|
||||
|
||||
def fetch_llm_recommendations_from_github(
|
||||
@@ -142,8 +148,9 @@ def sync_llm_models_from_github(
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the cache timestamp. Useful for testing."""
|
||||
"""Reset the cache timestamp in Redis. Useful for testing."""
|
||||
try:
|
||||
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
redis_client = get_redis_client()
|
||||
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset cache: {e}")
|
||||
logger.warning(f"Failed to reset cache in Redis: {e}")
|
||||
|
||||
@@ -32,14 +32,11 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -257,53 +254,8 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
|
||||
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
|
||||
|
||||
The Postgres cache backend eliminates the Redis dependency, but only works
|
||||
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
|
||||
"""
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
|
||||
raise RuntimeError(
|
||||
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
|
||||
"The Postgres cache backend is only supported in no-vector-DB "
|
||||
"deployments where Celery is replaced by the in-process task runner."
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
|
||||
since these modes require infrastructure that is removed in no-vector-DB deployments.
|
||||
"""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return
|
||||
|
||||
if MULTI_TENANT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
|
||||
"Multi-tenant deployments require the vector database for "
|
||||
"per-tenant document indexing and search. Run in single-tenant "
|
||||
"mode when disabling the vector database."
|
||||
)
|
||||
|
||||
from onyx.server.features.build.configs import ENABLE_CRAFT
|
||||
|
||||
if ENABLE_CRAFT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
|
||||
"Onyx Craft requires background workers for sandbox lifecycle "
|
||||
"management, which are removed in no-vector-DB deployments. "
|
||||
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -372,20 +324,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -3,12 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -245,44 +243,6 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"check_seat_availability",
|
||||
None,
|
||||
)
|
||||
seat_result = check_seat_fn(db_session=db_session)
|
||||
if seat_result is not None and not seat_result.available:
|
||||
logger.info(
|
||||
f"Blocked inactive Slack user {message_info.email}: "
|
||||
f"{seat_result.error_message}"
|
||||
)
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
text=(
|
||||
"We weren't able to respond because your organization "
|
||||
"has reached its user seat limit. Your account is "
|
||||
"currently deactivated and cannot be reactivated "
|
||||
"until more seats are available. Please contact "
|
||||
"your Onyx administrator."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
activate_user(existing_user, db_session)
|
||||
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"invalidate_license_cache",
|
||||
None,
|
||||
)
|
||||
invalidate_license_cache_fn()
|
||||
logger.info(f"Reactivated inactive Slack user {message_info.email}")
|
||||
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
|
||||
@@ -1133,8 +1133,7 @@ done
|
||||
# Already deleted
|
||||
service_deleted = True
|
||||
else:
|
||||
logger.error(f"Error deleting Service {service_name}: {e}")
|
||||
raise
|
||||
logger.warning(f"Error deleting Service {service_name}: {e}")
|
||||
|
||||
pod_deleted = False
|
||||
try:
|
||||
@@ -1149,8 +1148,7 @@ done
|
||||
# Already deleted
|
||||
pod_deleted = True
|
||||
else:
|
||||
logger.error(f"Error deleting Pod {pod_name}: {e}")
|
||||
raise
|
||||
logger.warning(f"Error deleting Pod {pod_name}: {e}")
|
||||
|
||||
# Wait for resources to be fully deleted to prevent 409 conflicts
|
||||
# on immediate re-provisioning
|
||||
|
||||
@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Prevent overlapping runs of this task
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
persona_id: int,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
is_default_request: IsDefaultRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_featured(
|
||||
update_persona_is_default(
|
||||
persona_id=persona_id,
|
||||
featured=is_featured_request.featured,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona featured status")
|
||||
logger.exception("Failed to update persona default status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -107,7 +108,11 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -123,7 +128,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
featured: bool = False
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -150,6 +155,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -167,7 +175,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -206,6 +214,8 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -220,7 +230,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -242,9 +252,11 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -253,6 +265,7 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -276,9 +289,11 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -309,6 +324,7 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -316,10 +332,12 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on persona's internal settings
|
||||
# Model with full context on perona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -342,7 +360,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -373,7 +391,10 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -13,7 +12,13 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -29,6 +34,7 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -49,27 +55,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -125,7 +111,6 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -152,12 +137,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -207,7 +192,6 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -224,6 +208,7 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -239,7 +224,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -252,7 +237,6 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -284,7 +268,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -351,7 +335,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -370,11 +354,13 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
)
|
||||
|
||||
|
||||
@@ -440,7 +426,6 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -473,25 +458,15 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
llm = get_default_llm()
|
||||
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -8,10 +8,10 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.cache.factory import get_shared_cache_backend
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
@@ -113,46 +113,60 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
etag = cache.get(REDIS_KEY_ETAG)
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8")
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag: {e}")
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
raw = cache.get(REDIS_KEY_FETCHED_AT)
|
||||
if not raw:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
return None
|
||||
|
||||
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from cache: {e}")
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to cache: {e}")
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
@@ -182,10 +196,11 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
cache = get_shared_cache_backend()
|
||||
lock = cache.lock(
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -133,7 +134,11 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -142,7 +147,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -97,6 +97,7 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -237,9 +233,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -269,7 +268,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.model,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -304,7 +303,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
) -> list[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -329,15 +328,7 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -353,44 +344,18 @@ def put_llm_provider(
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -428,6 +393,22 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -457,8 +438,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -479,48 +460,37 @@ def put_llm_provider(
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
def delete_llm_provider(
|
||||
provider_id: int,
|
||||
force: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
def set_provider_as_default(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/default-vision")
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
default_model: DefaultModel,
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=default_model.provider_id,
|
||||
vision_model=default_model.model_name,
|
||||
db_session=db_session,
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@@ -546,7 +516,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -575,13 +545,7 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return vision_provider_response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -591,7 +555,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -613,9 +577,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -628,15 +592,7 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return accessible_providers
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -648,7 +604,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
Public providers are always included.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -662,7 +618,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -679,11 +635,11 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -712,7 +668,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects persona restrictions
|
||||
# Use simplified access check - public providers always included
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -726,51 +682,7 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
# Get the default model and vision model for the persona
|
||||
# TODO: Port persona's over to use ID
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text = DefaultModel.from_model_config(default_text_model)
|
||||
default_vision = DefaultModel.from_model_config(default_vision_model)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider and can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -25,22 +21,50 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
# There is no logic that requires sending the providers default vision model name
|
||||
# We only send for the one that is actually the default
|
||||
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
|
||||
"""Find the default conversation model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns empty string.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
|
||||
return model_config.name
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
|
||||
"""Find the default vision model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns None.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
|
||||
return model_config.name
|
||||
return None
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -56,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -72,12 +99,24 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -91,17 +130,18 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -117,6 +157,8 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -138,6 +180,16 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -150,6 +202,10 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -369,38 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -198,6 +198,7 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
@@ -13,13 +13,13 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import convert_chat_history_basic
|
||||
@@ -67,6 +67,7 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
@@ -151,20 +152,10 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -172,16 +163,11 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -195,8 +181,7 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
],
|
||||
has_more=has_more,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -329,7 +314,7 @@ def get_chat_session(
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_cache_backend())
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
@@ -926,10 +911,11 @@ async def search_chats(
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(current_user), # noqa: ARG001
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal.
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -192,7 +192,6 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -19,7 +19,6 @@ class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GRACE_PERIOD = "grace_period"
|
||||
GATED_ACCESS = "gated_access"
|
||||
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -83,10 +82,6 @@ class Settings(BaseModel):
|
||||
# Default Assistant settings
|
||||
disable_default_assistant: bool | None = False
|
||||
|
||||
# Seat usage - populated by license enforcement when seat limit is exceeded
|
||||
seat_count: int | None = None
|
||||
used_seats: int | None = None
|
||||
|
||||
# OpenSearch migration
|
||||
opensearch_indexing_enabled: bool = False
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -7,8 +6,11 @@ from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -31,22 +33,30 @@ def load_settings() -> Settings:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
cache = get_cache_backend()
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
if value is not None:
|
||||
assert isinstance(value, bytes)
|
||||
anonymous_user_enabled = int(value.decode("utf-8")) == 1
|
||||
else:
|
||||
# Default to False
|
||||
anonymous_user_enabled = False
|
||||
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
|
||||
# Optionally store the default back to Redis
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
|
||||
# Log the error and reset to default
|
||||
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
|
||||
# Override user knowledge setting if disabled via environment variable
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
@@ -56,10 +66,11 @@ def load_settings() -> Settings:
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
cache = get_cache_backend()
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if settings.anonymous_user_enabled is not None:
|
||||
cache.set(
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
|
||||
"1" if settings.anonymous_user_enabled else "0",
|
||||
ex=SETTINGS_TTL,
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -255,18 +254,14 @@ def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
provider_name = "DevEnvPresetOpenAI"
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
id=existing.id if existing else None,
|
||||
name=provider_name,
|
||||
name="DevEnvPresetOpenAI",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -278,9 +273,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -13,9 +12,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -102,32 +98,16 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json().get("status") == "ok"
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -107,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
client = CodeInterpreterClient()
|
||||
return client.health(use_cache=True)
|
||||
return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -8,3 +8,37 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -528,7 +528,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.4
|
||||
lxml-html-clean==0.4.3
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.5
|
||||
pypdf==6.7.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
@@ -73,6 +74,10 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
user=None, # System persona
|
||||
name=f"Test Persona {uuid.uuid4()}",
|
||||
description="Test persona with no tools for web search test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
"""External dependency unit tests for periodic task claiming.
|
||||
|
||||
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
|
||||
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
|
||||
|
||||
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
|
||||
timestamp for cross-instance dedup. The DB session is released before
|
||||
the task runs, so long-running tasks don't hold connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.periodic_poller import _PeriodicTaskDef
|
||||
from onyx.background.periodic_poller import _try_claim_task
|
||||
from onyx.background.periodic_poller import _try_run_periodic_task
|
||||
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import KVStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
_TEST_LOCK_BASE = 90_000
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def _init_engine() -> None:
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
|
||||
def _make_task(
|
||||
*,
|
||||
name: str | None = None,
|
||||
interval: float = 3600,
|
||||
lock_id: int | None = None,
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name if name is not None else f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
|
||||
run_fn=run_fn if run_fn is not None else MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_kv(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KVStore).filter(
|
||||
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_claim_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimHappyPath:
|
||||
def test_first_claim_succeeds(self) -> None:
|
||||
assert _try_claim_task(_make_task()) is True
|
||||
|
||||
def test_first_claim_creates_kv_row(self) -> None:
|
||||
task = _make_task()
|
||||
_try_claim_task(task)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = (
|
||||
db_session.query(KVStore)
|
||||
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
|
||||
.first()
|
||||
)
|
||||
assert row is not None
|
||||
assert row.value is not None
|
||||
|
||||
def test_second_claim_within_interval_fails(self) -> None:
|
||||
task = _make_task(interval=3600)
|
||||
assert _try_claim_task(task) is True
|
||||
assert _try_claim_task(task) is False
|
||||
|
||||
def test_claim_after_interval_succeeds(self) -> None:
|
||||
task = _make_task(interval=1)
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
assert row is not None
|
||||
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||
db_session.commit()
|
||||
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_run_periodic_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunHappyPath:
|
||||
def test_runs_task_and_updates_last_run_at(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_called_once()
|
||||
assert task.last_run_at > 0
|
||||
|
||||
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn, interval=3600)
|
||||
task.last_run_at = time.monotonic()
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_skips_when_db_claim_blocked(self) -> None:
|
||||
name = f"test-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 10
|
||||
|
||||
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
|
||||
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_task_exception_does_not_propagate(self) -> None:
|
||||
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
def test_claim_committed_before_task_runs(self) -> None:
|
||||
"""The KV claim must be visible in the DB when run_fn executes."""
|
||||
task_name = f"test-order-{uuid4().hex[:8]}"
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
|
||||
claim_visible: list[bool] = []
|
||||
|
||||
def check_claim() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
claim_visible.append(row is not None and row.value is not None)
|
||||
|
||||
task = _PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=_TEST_LOCK_BASE + 11,
|
||||
run_fn=check_claim,
|
||||
)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
assert claim_visible == [True]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrency: only one claimer should win
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimConcurrency:
|
||||
def test_concurrent_claims_single_winner(self) -> None:
|
||||
"""Many threads claim the same task — exactly one should succeed."""
|
||||
num_threads = 20
|
||||
task_name = f"test-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 20
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
results: list[bool] = []
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
results.append(future.result())
|
||||
|
||||
winners = sum(1 for r in results if r)
|
||||
assert winners == 1, f"Expected 1 winner, got {winners}"
|
||||
|
||||
def test_concurrent_run_single_execution(self) -> None:
|
||||
"""Many threads run the same task — run_fn fires exactly once."""
|
||||
num_threads = 20
|
||||
task_name = f"test-run-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 21
|
||||
counter = MagicMock()
|
||||
|
||||
def run() -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
_try_run_periodic_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=counter,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(run) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
assert (
|
||||
counter.call_count == 1
|
||||
), f"Expected run_fn called once, got {counter.call_count}"
|
||||
|
||||
def test_no_errors_under_contention(self) -> None:
|
||||
"""All threads complete without exceptions under high contention."""
|
||||
num_threads = 30
|
||||
task_name = f"test-err-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 22
|
||||
errors: list[Exception] = []
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
assert errors == [], f"Got {len(errors)} errors: {errors}"
|
||||
@@ -1,352 +0,0 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``*_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
|
||||
db_session: Session,
|
||||
user_id: object,
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
needs_project_sync: bool = False,
|
||||
needs_persona_sync: bool = False,
|
||||
) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=status,
|
||||
needs_project_sync=needs_project_sync,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _fake_delete_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: delete the row so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _fake_sync_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: clear sync flags so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == UUID(user_file_id))
|
||||
.values(needs_project_sync=False, needs_persona_sync=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
created: list[UserFile] = []
|
||||
yield created
|
||||
for uf in created:
|
||||
existing = db_session.get(UserFile, uf.id)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
|
||||
"""Files in PROCESSING status are re-processed via the processing drain loop."""
|
||||
|
||||
def test_processing_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_proc")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
|
||||
|
||||
def test_completed_files_not_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_comp")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) not in called_ids
|
||||
), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
|
||||
"""Files in DELETING status are recovered via the delete drain loop."""
|
||||
|
||||
def test_deleting_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
# Row is deleted by _fake_delete_impl, so no cleanup needed.
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_delete_impl)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
|
||||
"""Files needing project/persona sync are recovered via the sync drain loop."""
|
||||
|
||||
def test_needs_project_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_sync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
|
||||
|
||||
def test_needs_persona_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_psync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
|
||||
"""Recovery processes all stuck files in one pass, not just the first."""
|
||||
|
||||
def test_multiple_processing_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_multi")
|
||||
files = []
|
||||
for _ in range(3):
|
||||
uf = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
|
||||
expected_ids = {str(uf.id) for uf in files}
|
||||
assert expected_ids.issubset(called_ids), (
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestTransientFailures:
|
||||
"""Drain loops skip failed files, process the rest, and terminate."""
|
||||
|
||||
def test_processing_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_proc")
|
||||
uf_fail = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_delete_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_del")
|
||||
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf_fail)
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_sync_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_sync")
|
||||
uf_fail = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
uf_ok = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
|
||||
|
||||
Verifies that the KV store correctly uses the CacheBackend for caching
|
||||
in front of PostgreSQL: cache hits, cache misses falling through to PG,
|
||||
cache population after PG reads, cache invalidation on delete, and
|
||||
graceful degradation when the cache backend raises.
|
||||
|
||||
Requires running PostgreSQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import CacheStore
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.key_value_store.store import REDIS_KEY_PREFIX
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_kv() -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(KVStore))
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
|
||||
return PgRedisKVStore(cache=pg_cache)
|
||||
|
||||
|
||||
class TestStoreAndLoad:
|
||||
def test_store_populates_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k1", {"hello": "world"})
|
||||
|
||||
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
|
||||
assert cached is not None
|
||||
assert json.loads(cached) == {"hello": "world"}
|
||||
|
||||
loaded = kv_store.load("k1")
|
||||
assert loaded == {"hello": "world"}
|
||||
|
||||
def test_load_returns_cached_value_without_pg_hit(
|
||||
self, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
"""If the cache already has the value, PG should not be queried."""
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
|
||||
kv = PgRedisKVStore(cache=pg_cache)
|
||||
assert kv.load("cached_only") == {"from": "cache"}
|
||||
|
||||
def test_load_falls_through_to_pg_on_cache_miss(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k2", [1, 2, 3])
|
||||
|
||||
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
|
||||
|
||||
loaded = kv_store.load("k2")
|
||||
assert loaded == [1, 2, 3]
|
||||
|
||||
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
|
||||
assert repopulated is not None
|
||||
assert json.loads(repopulated) == [1, 2, 3]
|
||||
|
||||
def test_load_with_refresh_cache_skips_cache(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k3", "original")
|
||||
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
|
||||
|
||||
loaded = kv_store.load("k3", refresh_cache=True)
|
||||
assert loaded == "original"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_from_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("del_me", "bye")
|
||||
kv_store.delete("del_me")
|
||||
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
|
||||
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.load("del_me")
|
||||
|
||||
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.delete("nonexistent")
|
||||
|
||||
|
||||
class TestCacheFailureGracefulDegradation:
|
||||
def test_store_succeeds_when_cache_set_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("resilient", {"data": True})
|
||||
|
||||
working_cache = MagicMock(spec=CacheBackend)
|
||||
working_cache.get.return_value = None
|
||||
kv_reader = PgRedisKVStore(cache=working_cache)
|
||||
loaded = kv_reader.load("resilient")
|
||||
assert loaded == {"data": True}
|
||||
|
||||
def test_load_falls_through_when_cache_get_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.get.side_effect = ConnectionError("cache down")
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("survive", 42)
|
||||
loaded = kv.load("survive")
|
||||
assert loaded == 42
|
||||
@@ -1,229 +0,0 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
stmt = select(CacheStore.key).where(CacheStore.key == k)
|
||||
with get_session_with_current_tenant() as session:
|
||||
row = session.execute(stmt).first()
|
||||
assert row is None, "expired row should be physically deleted"
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -36,6 +36,7 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -85,6 +86,12 @@ def _create_test_persona(
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
@@ -403,6 +410,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -431,6 +442,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -446,11 +461,16 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -481,6 +501,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -495,10 +519,15 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -57,6 +58,12 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -45,14 +44,15 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> LLMProviderView:
|
||||
) -> None:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
return upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -102,11 +102,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -146,11 +152,17 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -182,9 +194,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -192,12 +202,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -231,9 +246,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -241,12 +254,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -279,7 +297,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
provider = upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -287,6 +305,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -302,13 +321,18 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -344,11 +368,17 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model=model_name,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -412,6 +442,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -421,7 +452,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -441,6 +472,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -467,11 +499,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -480,9 +512,6 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -495,7 +524,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -567,6 +596,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -575,7 +605,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -49,6 +49,7 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -90,14 +91,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -124,16 +125,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -158,16 +159,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -191,14 +190,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -224,16 +223,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -262,14 +259,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -300,6 +297,7 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -324,7 +322,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -332,11 +330,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -364,15 +362,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -401,7 +399,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -409,13 +407,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -440,17 +438,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -476,7 +474,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -484,10 +482,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -532,8 +530,14 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -542,10 +546,11 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -564,9 +569,14 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -576,9 +586,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -606,13 +616,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
assert provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {db_provider.custom_config}"
|
||||
f"but got {provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -632,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -654,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -695,7 +706,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -708,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -730,9 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -15,11 +15,9 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -137,6 +135,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -164,8 +163,13 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -234,6 +238,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -312,6 +317,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -320,13 +326,13 @@ class TestAutoModeSyncFeature:
|
||||
)
|
||||
|
||||
# Verify initial state: all models are visible
|
||||
initial_provider = fetch_existing_llm_provider(
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert initial_provider is not None
|
||||
assert initial_provider.is_auto_mode is False
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
|
||||
for mc in initial_provider.model_configurations:
|
||||
for mc in provider.model_configurations:
|
||||
assert (
|
||||
mc.is_visible is True
|
||||
), f"Initial model '{mc.name}' should be visible"
|
||||
@@ -338,12 +344,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=initial_provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -354,15 +360,15 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider_view = fetch_llm_provider_view(
|
||||
provider_name=provider_name, db_session=db_session
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider_view is not None
|
||||
assert provider_view.is_auto_mode is True
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
|
||||
# Build a map of model name -> visibility
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in provider_view.model_configurations
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Models in auto mode config should be visible
|
||||
@@ -382,6 +388,9 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -423,12 +432,8 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -530,6 +535,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -543,7 +549,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -557,6 +563,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -577,7 +584,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
@@ -637,6 +644,7 @@ class TestAutoModeMissingFlows:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -693,364 +701,3 @@ class TestAutoModeMissingFlows:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Disabling auto mode should preserve the current model list and
|
||||
prevent future syncs from altering visibility.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode — models synced from config.
|
||||
2. Update provider to manual mode (is_auto_mode=False).
|
||||
3. Verify all models remain with unchanged visibility.
|
||||
4. Call sync_auto_mode_models with a *different* config.
|
||||
5. Verify fetch_auto_mode_providers excludes this provider, so the
|
||||
periodic task would never call sync on it.
|
||||
"""
|
||||
initial_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create in auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=initial_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility_before = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
|
||||
|
||||
# Step 2: Switch to manual mode
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=False,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Models unchanged
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
visibility_after = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_after == visibility_before
|
||||
|
||||
# Step 4-5: Provider excluded from auto mode queries
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
auto_provider_ids = {p.id for p in auto_providers}
|
||||
assert provider.id not in auto_provider_ids
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_resync_adds_new_and_hides_removed_models(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the GitHub config changes between syncs, a subsequent sync
|
||||
should add newly listed models and hide models that were removed.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
|
||||
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
|
||||
gpt-4-turbo added).
|
||||
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
|
||||
and visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create with config v1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Re-sync with config v2
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 3: Verify
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# gpt-4o: still in config -> visible
|
||||
assert visibility["gpt-4o"] is True
|
||||
# gpt-4o-mini: removed from config -> hidden (not deleted)
|
||||
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
|
||||
assert visibility["gpt-4o-mini"] is False
|
||||
# gpt-4-turbo: newly added -> visible
|
||||
assert visibility["gpt-4-turbo"] is True
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_is_idempotent(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Running sync twice with the same config should produce zero
|
||||
changes on the second call."""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
# First explicit sync (may report changes if creation already synced)
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Snapshot state after first sync
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
snapshot = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Second sync — should be a no-op
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert (
|
||||
changes == 0
|
||||
), f"Expected 0 changes on idempotent re-sync, got {changes}"
|
||||
|
||||
# State should be identical
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
current = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
assert current == snapshot
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_default_model_hidden_when_removed_from_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the current default model is removed from the config, sync
|
||||
should hide it. The default model flow row should still exist (it
|
||||
points at the ModelConfiguration), but the model is no longer visible.
|
||||
|
||||
Steps:
|
||||
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
|
||||
2. Set gpt-4o as the global default.
|
||||
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
|
||||
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
|
||||
fetch_default_llm_model still returns a result (the flow row persists).
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=[],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Set gpt-4o as global default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Step 3: Re-sync with config v2 (gpt-4o removed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 4: Verify visibility
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -64,6 +64,7 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -42,6 +42,7 @@ def _create_llm_provider_and_model(
|
||||
name=provider_name,
|
||||
provider="openai",
|
||||
api_key="test-api-key",
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -54,6 +55,11 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
|
||||
persona = Persona(
|
||||
name=f"test_slack_persona_{unique_id}",
|
||||
description="Test persona for Slack federated search",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question based on the provided context.",
|
||||
)
|
||||
@@ -428,6 +434,7 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -441,7 +448,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
@@ -812,6 +819,11 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
|
||||
persona = Persona(
|
||||
name=f"test_eager_load_persona_{unique_id}",
|
||||
description="Test persona for eager loading test",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question.",
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -46,6 +47,12 @@ def _create_test_persona_with_mcp_tool(
|
||||
persona = Persona(
|
||||
name=f"Test MCP Persona {uuid4().hex[:8]}",
|
||||
description="Test persona with MCP tools",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -17,6 +17,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -56,6 +57,12 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -933,7 +933,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from fastapi.background import BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -1027,13 +1026,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self._capture("GET", b"")
|
||||
if self.path == "/health":
|
||||
self._respond_json(200, {"status": "ok"})
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
self._capture("DELETE", b"")
|
||||
self.send_response(200)
|
||||
@@ -1114,14 +1106,6 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
|
||||
"""Ensure the default persona (id=0) has the PythonTool attached."""
|
||||
@@ -1155,7 +1139,6 @@ def test_code_interpreter_receives_chat_files(
|
||||
# Upload a test CSV
|
||||
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
|
||||
result = upload_user_files(
|
||||
bg_tasks=BackgroundTasks(),
|
||||
files=[
|
||||
UploadFile(
|
||||
file=io.BytesIO(csv_content),
|
||||
|
||||
@@ -4,12 +4,10 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -34,6 +32,7 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -66,7 +65,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -76,19 +75,9 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -115,7 +104,7 @@ class LLMProviderManager:
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**p) for p in response.json()["providers"]]
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -124,11 +113,7 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -141,30 +126,11 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and (
|
||||
default_model is None or default_model.model_name in model_names
|
||||
)
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
default_text = response.json().get("default_text")
|
||||
if default_text is None:
|
||||
return None
|
||||
return DefaultModel(**default_text)
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -19,7 +20,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -30,7 +35,6 @@ class PersonaManager:
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
display_priority: int | None = None,
|
||||
featured: bool = False,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
@@ -43,7 +47,11 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -53,7 +61,6 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -68,7 +75,11 @@ class PersonaManager:
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
@@ -79,7 +90,6 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -90,7 +100,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -99,7 +113,6 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
featured: bool | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -110,7 +123,13 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=persona.is_public if is_public is None else is_public,
|
||||
llm_filter_extraction=(
|
||||
llm_filter_extraction or persona.llm_filter_extraction
|
||||
),
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
@@ -122,7 +141,6 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -137,12 +155,16 @@ class PersonaManager:
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
|
||||
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
@@ -151,8 +173,7 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -201,13 +222,32 @@ class PersonaManager:
|
||||
fetched_persona.description,
|
||||
)
|
||||
)
|
||||
if fetched_persona.num_chunks != persona.num_chunks:
|
||||
mismatches.append(
|
||||
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
|
||||
)
|
||||
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_relevance_filter",
|
||||
persona.llm_relevance_filter,
|
||||
fetched_persona.llm_relevance_filter,
|
||||
)
|
||||
)
|
||||
if fetched_persona.is_public != persona.is_public:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if fetched_persona.featured != persona.featured:
|
||||
if (
|
||||
fetched_persona.llm_filter_extraction
|
||||
!= persona.llm_filter_extraction
|
||||
):
|
||||
mismatches.append(
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
(
|
||||
"llm_filter_extraction",
|
||||
persona.llm_filter_extraction,
|
||||
fetched_persona.llm_filter_extraction,
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
|
||||
|
||||
class ScimClient:
|
||||
"""HTTP client for making authenticated SCIM v2 requests."""
|
||||
|
||||
@staticmethod
|
||||
def _headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def post(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def put(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.put(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def patch(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.delete(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -50,3 +51,29 @@ class ScimTokenManager:
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import Field
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -127,7 +128,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str | None = None
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
@@ -161,7 +162,11 @@ class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
@@ -169,7 +174,6 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -8,6 +8,7 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
@@ -35,8 +36,14 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -414,24 +414,6 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Pause the connector immediately to prevent check_for_indexing from
|
||||
# creating automatic retry attempts while we reset the mock server.
|
||||
# Without this, the INITIAL_INDEXING status causes immediate retries
|
||||
# that would consume (or fail against) the mock server before we can
|
||||
# set up the recovery behavior.
|
||||
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
|
||||
# Collect all index attempt IDs created so far (the initial one plus
|
||||
# any automatic retries that may have started before the pause took effect).
|
||||
all_prior_attempt_ids: list[int] = []
|
||||
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair.id,
|
||||
page=0,
|
||||
page_size=100,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
@@ -483,14 +465,17 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Set the manual indexing trigger, then unpause to allow the recovery run.
|
||||
# After the failure, the connector is in repeated error state and paused.
|
||||
# Set the manual indexing trigger first (while paused), then unpause.
|
||||
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
|
||||
# prevent the connector from being re-paused when repeated error state is detected.
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=all_prior_attempt_ids,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
|
||||
@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
|
||||
# )
|
||||
break
|
||||
|
||||
if time.monotonic() - start_time > 90:
|
||||
assert False, "CC pair did not enter repeated error state within 90 seconds"
|
||||
if time.monotonic() - start_time > 30:
|
||||
assert False, "CC pair did not enter repeated error state within 30 seconds"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user