mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 14:45:46 +00:00
Compare commits
53 Commits
worktree-o
...
def_ci_url
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15c828d0e2 | ||
|
|
6aa56821d6 | ||
|
|
eda436de01 | ||
|
|
07915a6c01 | ||
|
|
2c3e9aecd1 | ||
|
|
fa29cc3849 | ||
|
|
24ac8b37d3 | ||
|
|
be8b108ae4 | ||
|
|
f380a75df3 | ||
|
|
21ec93663b | ||
|
|
d789c74024 | ||
|
|
fe014776f7 | ||
|
|
700ca0e0fc | ||
|
|
a84f8238ec | ||
|
|
4fc802e19d | ||
|
|
6cfd49439a | ||
|
|
71a1faa47e | ||
|
|
1a65217baf | ||
|
|
30fa43b5fc | ||
|
|
28332fa24b | ||
|
|
1f5050f9f6 | ||
|
|
3c1d29d3cf | ||
|
|
709e3f4ca7 | ||
|
|
dfa27c08ef | ||
|
|
13d60dcb0e | ||
|
|
30704f427f | ||
|
|
4f3c54f282 | ||
|
|
580d41dc23 | ||
|
|
897e181d67 | ||
|
|
fd322a8a10 | ||
|
|
11c54bafb5 | ||
|
|
c93617df5d | ||
|
|
0cdd438f46 | ||
|
|
31aef36f78 | ||
|
|
0c35dfc0e4 | ||
|
|
a9769757fe | ||
|
|
15d8946f40 | ||
|
|
ba79539d6d | ||
|
|
59d3725fc6 | ||
|
|
9c05bd215d | ||
|
|
4d2aa09654 | ||
|
|
16c07c8756 | ||
|
|
3fb4f5d6e6 | ||
|
|
14fab7fcdf | ||
|
|
22a335fffa | ||
|
|
b0f7466eba | ||
|
|
b1d42726b1 | ||
|
|
7d922bffc1 | ||
|
|
de7fc36fc5 | ||
|
|
7f9e37450d | ||
|
|
c7ef85b733 | ||
|
|
bd9319e592 | ||
|
|
db5955d6f2 |
@@ -54,6 +54,7 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
|
||||
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,8 +426,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -499,8 +500,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -646,8 +648,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -728,8 +730,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -862,8 +864,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -934,8 +937,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1072,8 +1076,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1145,8 +1149,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1287,8 +1291,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1366,8 +1371,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -15,6 +15,9 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
@@ -25,16 +28,6 @@ jobs:
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
|
||||
@@ -114,8 +114,10 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,7 +21,13 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -52,21 +58,14 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -48,28 +48,10 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
@@ -81,6 +63,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -89,6 +72,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
@@ -97,8 +93,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -110,6 +106,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -118,6 +115,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
@@ -126,8 +136,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
@@ -138,6 +148,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -146,6 +157,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
@@ -154,8 +178,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
@@ -170,56 +194,56 @@ jobs:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
@@ -230,6 +254,7 @@ jobs:
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -238,21 +263,43 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: 4a1e4b1c89d2
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "4a1e4b1c89d2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -0,0 +1,112 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""make scim_user_mapping.external_id nullable
|
||||
|
||||
Revision ID: a3b8d9e2f1c4
|
||||
Revises: 2664261bfaab
|
||||
Create Date: 2026-03-02
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3b8d9e2f1c4"
|
||||
down_revision = "2664261bfaab"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete any rows where external_id is NULL before re-applying NOT NULL
|
||||
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -126,12 +126,16 @@ class ScimDAL(DAL):
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str,
|
||||
external_id: str | None,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
"""Create a SCIM mapping for a user.
|
||||
|
||||
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
|
||||
allows this). The mapping still marks the user as SCIM-managed.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
@@ -270,8 +274,13 @@ class ScimDAL(DAL):
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
|
||||
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
|
||||
# unless they were explicitly linked via SCIM provisioning.
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -321,34 +330,37 @@ class ScimDAL(DAL):
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
"""Sync the SCIM mapping for a user.
|
||||
|
||||
If a mapping already exists, its fields are updated (including
|
||||
setting ``external_id`` to ``None`` when the IdP omits it).
|
||||
If no mapping exists and ``new_external_id`` is provided, a new
|
||||
mapping is created. A mapping is never deleted here — SCIM-managed
|
||||
users must retain their mapping to remain visible in ``GET /Users``.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
elif new_external_id:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -167,6 +168,7 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,7 +15,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -24,6 +26,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
@@ -77,6 +80,22 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -404,21 +423,63 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
|
||||
# Enforce seat limit
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -436,18 +497,21 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
|
||||
@@ -19,7 +19,6 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -82,23 +96,14 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
|
||||
return token
|
||||
|
||||
@@ -153,26 +153,31 @@ class ScimProvider(ABC):
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
) -> ScimName:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
# Derive a reasonable name from the email so that SCIM spec tests
|
||||
# see non-empty givenName / familyName for every user resource.
|
||||
local = user.email.split("@")[0] if user.email else ""
|
||||
return ScimName(givenName=local, familyName="", formatted=local)
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
@@ -117,15 +117,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
if llm_upsert_requests:
|
||||
logger.notice("Seeding LLMs")
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
if not llm_upsert_requests:
|
||||
return
|
||||
|
||||
logger.notice("Seeding LLMs")
|
||||
for request in llm_upsert_requests:
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
)
|
||||
if not default_provider:
|
||||
return
|
||||
|
||||
visible_configs = [
|
||||
mc for mc in default_provider.model_configurations if mc.is_visible
|
||||
]
|
||||
default_config = (
|
||||
visible_configs[0]
|
||||
if visible_configs
|
||||
else default_provider.model_configurations[0]
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=default_provider.id,
|
||||
model_name=default_config.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
@@ -137,12 +160,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -154,6 +171,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
settings.ee_features_enabled = False
|
||||
elif metadata.used_seats > metadata.seats:
|
||||
# License is valid but seat limit exceeded
|
||||
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
|
||||
settings.seat_count = metadata.seats
|
||||
settings.used_seats = metadata.used_seats
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=request.name, db_session=db_session
|
||||
)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
@@ -725,11 +725,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,34 +414,31 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -449,15 +446,18 @@ def process_single_user_file(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -471,7 +471,6 @@ def process_single_user_file(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -493,9 +492,8 @@ def process_single_user_file(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -504,33 +502,43 @@ def process_single_user_file(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -581,36 +589,38 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -648,7 +658,6 @@ def process_single_user_file_delete(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -656,26 +665,34 @@ def process_single_user_file_delete(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -747,32 +764,30 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -783,11 +798,10 @@ def process_single_user_file_project_sync(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -822,7 +836,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -835,11 +849,22 @@ def process_single_user_file_project_sync(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
307
backend/onyx/background/periodic_poller.py
Normal file
307
backend/onyx/background/periodic_poller.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,3 +1,33 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -9,3 +39,168 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process user file {file_id}")
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to sync user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
51
backend/onyx/cache/factory.py
vendored
Normal file
51
backend/onyx/cache/factory.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
CacheBackendType.POSTGRES: _build_postgres_backend,
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
104
backend/onyx/cache/interface.py
vendored
Normal file
104
backend/onyx/cache/interface.py
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
if not self.acquire():
|
||||
raise RuntimeError("Failed to acquire lock")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds.
|
||||
|
||||
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
|
||||
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
@@ -0,0 +1,323 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
|
||||
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
|
||||
# lists whose names share a prefix (e.g. _q:mylist2:...).
|
||||
_LIST_KEY_RANGE_TERMINATOR = ";"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_LOCK_POLL_INTERVAL = 0.1
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
|
||||
collision when concurrent rpush calls occur within the same nanosecond.
|
||||
"""
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
|
||||
"""Advisory-lock-based distributed lock.
|
||||
|
||||
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
|
||||
to the session's connection; releasing or closing the session frees it.
|
||||
|
||||
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
|
||||
``timeout`` seconds. They are released when ``release()`` is
|
||||
called or when the session is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
|
||||
self._lock_id = lock_id
|
||||
self._timeout = timeout
|
||||
self._tenant_id = tenant_id
|
||||
self._session_cm: AbstractContextManager[Session] | None = None
|
||||
self._session: Session | None = None
|
||||
self._acquired = False
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
|
||||
self._session = self._session_cm.__enter__()
|
||||
try:
|
||||
if not blocking:
|
||||
return self._try_lock()
|
||||
|
||||
effective_timeout = blocking_timeout or self._timeout
|
||||
deadline = (
|
||||
(time.monotonic() + effective_timeout) if effective_timeout else None
|
||||
)
|
||||
while True:
|
||||
if self._try_lock():
|
||||
return True
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
return False
|
||||
time.sleep(_LOCK_POLL_INTERVAL)
|
||||
finally:
|
||||
if not self._acquired:
|
||||
self._close_session()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._acquired or self._session is None:
|
||||
return
|
||||
try:
|
||||
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._close_session()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
def _close_session(self) -> None:
|
||||
if self._session_cm is not None:
|
||||
try:
|
||||
self._session_cm.__exit__(None, None, None)
|
||||
finally:
|
||||
self._session_cm = None
|
||||
self._session = None
|
||||
|
||||
def _try_lock(self) -> bool:
|
||||
assert self._session is not None
|
||||
result = self._session.execute(
|
||||
select(func.pg_try_advisory_lock(self._lock_id))
|
||||
).scalar()
|
||||
if result:
|
||||
self._acquired = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
|
||||
|
||||
Each operation opens and closes its own database session so the backend
|
||||
is safe to share across threads. Tenant isolation is handled by
|
||||
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.value).where(
|
||||
CacheStore.key == key,
|
||||
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
value = session.execute(stmt).scalar_one_or_none()
|
||||
if value is None:
|
||||
return None
|
||||
return bytes(value)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
value_bytes = _to_bytes(value)
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=ex)
|
||||
if ex is not None
|
||||
else None
|
||||
)
|
||||
stmt = (
|
||||
pg_insert(CacheStore)
|
||||
.values(key=key, value=value_bytes, expires_at=expires_at)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[CacheStore.key],
|
||||
set_={"value": value_bytes, "expires_at": expires_at},
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(delete(CacheStore).where(CacheStore.key == key))
|
||||
session.commit()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = (
|
||||
select(CacheStore.key)
|
||||
.where(
|
||||
CacheStore.key == key,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
return session.execute(stmt).first() is not None
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
stmt = (
|
||||
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
expires_at: datetime | None = result[0]
|
||||
if expires_at is None:
|
||||
return TTL_NO_EXPIRY
|
||||
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
|
||||
if remaining <= 0:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
return int(remaining)
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return PostgresCacheLock(
|
||||
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
if timeout <= 0:
|
||||
raise ValueError(
|
||||
"PostgresCacheBackend.blpop requires timeout > 0. "
|
||||
"timeout=0 would block the calling thread indefinitely "
|
||||
"with no way to interrupt short of process termination."
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
for key in keys:
|
||||
lower = f"{_LIST_KEY_PREFIX}{key}:"
|
||||
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
|
||||
stmt = (
|
||||
select(CacheStore)
|
||||
.where(
|
||||
CacheStore.key >= lower,
|
||||
CacheStore.key < upper,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.order_by(CacheStore.key)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
row = session.execute(stmt).scalars().first()
|
||||
if row is not None:
|
||||
value = bytes(row.value) if row.value else b""
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return (key.encode(), value)
|
||||
if time.monotonic() >= deadline:
|
||||
return None
|
||||
time.sleep(_BLPOP_POLL_INTERVAL)
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
|
||||
"""Delete rows whose ``expires_at`` is in the past.
|
||||
|
||||
Called by the periodic poller every 5 minutes.
|
||||
"""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
delete(CacheStore).where(
|
||||
CacheStore.expires_at.is_not(None),
|
||||
CacheStore.expires_at < func.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
|
||||
|
||||
This is a thin pass-through — every method maps 1-to-1 to the underlying
|
||||
Redis command. ``TenantRedis`` key-prefixing is handled by the client
|
||||
itself (provided by ``get_redis_client``).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
val = self._r.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
return str(val).encode()
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self._r.set(key, value, ex=ex)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._r.delete(key)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return bool(self._r.exists(key))
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._r.expire(key, seconds)
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return cast(int, self._r.ttl(key))
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return RedisCacheLock(self._r.lock(name, timeout=timeout))
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self._r.rpush(key, value)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1])
|
||||
@@ -1,57 +1,52 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
"""Generate the cache key for a chat session processing fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
chat_session_id: UUID, cache: CacheBackend, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
"""Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
If the key exists, a message is being processed.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
return cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -530,11 +531,13 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file:"
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -558,12 +561,16 @@ def _create_context_files_message(
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
@@ -11,9 +11,10 @@ from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
@@ -79,7 +80,6 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
cache: CacheBackend | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
redis_client = get_redis_client()
|
||||
cache = get_cache_backend()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
redis_client,
|
||||
cache,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=True,
|
||||
)
|
||||
|
||||
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
if cache is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,65 +1,58 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
|
||||
FENCE_TTL = 10 * 60 # 10 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session stop signal fence.
|
||||
"""Generate the cache key for a chat session stop signal fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
|
||||
"""Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
if not value:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
return not cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
|
||||
"""Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -54,6 +55,12 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
@@ -812,7 +819,9 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
|
||||
@@ -98,6 +98,7 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -112,6 +113,9 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
"""
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -213,11 +206,29 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
existing_llm_provider: LLMProviderModel | None = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_llm_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} not found"
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
if existing_llm_provider.name != llm_provider_upsert_request.name:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
|
||||
)
|
||||
else:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with name '{llm_provider_upsert_request.name}'"
|
||||
" already exists"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -238,11 +249,7 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -306,15 +313,6 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -488,6 +486,22 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -604,22 +618,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
provider.default_model_name, # type: ignore[arg-type]
|
||||
model_name,
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -805,12 +810,6 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -103,7 +103,6 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -3265,19 +3264,6 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3304,11 +3290,8 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -4943,7 +4926,9 @@ class ScimUserMapping(Base):
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
external_id: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
@@ -5000,3 +4985,25 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
|
||||
"""Key-value cache table used by ``PostgresCacheBackend``.
|
||||
|
||||
Replaces Redis for simple KV caching, locks, and list operations
|
||||
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
|
||||
|
||||
Intentionally separate from ``KVStore``:
|
||||
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
|
||||
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
|
||||
- Holds ephemeral data (tokens, stop signals, lock state) not
|
||||
persistent application config, so cleanup can be aggressive.
|
||||
"""
|
||||
|
||||
__tablename__ = "cache_store"
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -18,11 +18,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -254,13 +251,15 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
# Curators can edit default personas, but not make them
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -281,7 +280,6 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -295,10 +293,7 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
featured=create_persona_request.featured,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -874,10 +869,6 @@ def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -898,13 +889,11 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
featured: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -1015,12 +1004,6 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1034,10 +1017,8 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1090,12 +1071,6 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1111,9 +1086,7 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
featured=(featured if featured is not None else False),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1158,9 +1131,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_is_default(
|
||||
def update_persona_featured(
|
||||
persona_id: int,
|
||||
is_default: bool,
|
||||
featured: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1168,7 +1141,7 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
persona.featured = featured
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -105,8 +106,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -127,16 +128,27 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,8 +5,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -45,8 +43,6 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -73,17 +69,13 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -563,12 +562,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -525,7 +526,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -546,3 +547,41 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
|
||||
@@ -4,39 +4,33 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key prefix for OAuth state
|
||||
OAUTH_STATE_PREFIX = "federated_oauth"
|
||||
# Default TTL for OAuth state (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
OAUTH_STATE_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
class OAuthSession:
|
||||
"""Represents an OAuth session stored in Redis."""
|
||||
"""Represents an OAuth session stored in the cache backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
):
|
||||
self.federated_connector_id = federated_connector_id
|
||||
self.user_id = user_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.additional_data = additional_data or {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for Redis storage."""
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"federated_connector_id": self.federated_connector_id,
|
||||
"user_id": self.user_id,
|
||||
@@ -45,8 +39,7 @@ class OAuthSession:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
|
||||
"""Create from dictionary retrieved from Redis."""
|
||||
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
|
||||
return cls(
|
||||
federated_connector_id=data["federated_connector_id"],
|
||||
user_id=data["user_id"],
|
||||
@@ -58,31 +51,27 @@ class OAuthSession:
|
||||
def generate_oauth_state(
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
ttl: int = OAUTH_STATE_TTL,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a secure state parameter and store session data in Redis.
|
||||
Generate a secure state parameter and store session data in the cache backend.
|
||||
|
||||
Args:
|
||||
federated_connector_id: ID of the federated connector
|
||||
user_id: ID of the user initiating OAuth
|
||||
redirect_uri: Optional redirect URI after OAuth completion
|
||||
additional_data: Any additional data to store with the session
|
||||
ttl: Time-to-live in seconds for the Redis key
|
||||
ttl: Time-to-live in seconds for the cache key
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter
|
||||
"""
|
||||
# Generate a random UUID for the state
|
||||
state_uuid = uuid.uuid4()
|
||||
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Convert UUID to base64 for URL-safe state parameter
|
||||
state_bytes = state_uuid.bytes
|
||||
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Create session object
|
||||
session = OAuthSession(
|
||||
federated_connector_id=federated_connector_id,
|
||||
user_id=user_id,
|
||||
@@ -90,15 +79,9 @@ def generate_oauth_state(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
|
||||
# Store in Redis with TTL
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
json.dumps(session.to_dict()),
|
||||
ex=ttl,
|
||||
)
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
|
||||
@@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> OAuthSession:
|
||||
state_bytes = base64.urlsafe_b64decode(padded_state)
|
||||
state_uuid = uuid.UUID(bytes=state_bytes)
|
||||
|
||||
# Look up in Redis
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
session_data = cast(bytes, redis_client.get(redis_key))
|
||||
session_data = cache.get(cache_key)
|
||||
if not session_data:
|
||||
raise ValueError(f"OAuth state not found in Redis: {state}")
|
||||
raise ValueError(f"OAuth state not found: {state}")
|
||||
|
||||
# Delete the key after retrieval (one-time use)
|
||||
redis_client.delete(redis_key)
|
||||
cache.delete(cache_key)
|
||||
|
||||
# Parse and return session
|
||||
session_dict = json.loads(session_data)
|
||||
return OAuthSession.from_dict(session_dict)
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -20,22 +18,27 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client()
|
||||
def __init__(self, cache: CacheBackend | None = None) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _get_cache(self) -> CacheBackend:
|
||||
if self._cache is None:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
self._cache = get_cache_backend()
|
||||
return self._cache
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback gracefully to Postgres if Redis fails
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
# Fallback gracefully to Postgres if Cache backend fails
|
||||
logger.error(
|
||||
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
@@ -53,16 +56,12 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
|
||||
if not refresh_cache:
|
||||
try:
|
||||
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
|
||||
if redis_value:
|
||||
if not isinstance(redis_value, bytes):
|
||||
raise ValueError(
|
||||
f"Redis value for key '{key}' is not a bytes object"
|
||||
)
|
||||
return json.loads(redis_value.decode("utf-8"))
|
||||
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
|
||||
if cached is not None:
|
||||
return json.loads(cached.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get value from Redis for key '{key}': {str(e)}"
|
||||
f"Failed to get value from cache for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -79,21 +78,21 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
|
||||
|
||||
return cast(JSON_ro, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
self.redis_client.delete(REDIS_KEY_PREFIX + key)
|
||||
self._get_cache().delete(REDIS_KEY_PREFIX + key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
|
||||
@@ -67,6 +67,18 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -725,6 +737,44 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -736,6 +786,7 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -743,3 +794,4 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -13,44 +13,38 @@ from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key for caching the last updated timestamp (per-tenant)
|
||||
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
|
||||
|
||||
|
||||
def _get_cached_last_updated_at() -> datetime | None:
|
||||
"""Get the cached last_updated_at timestamp from Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
if value and isinstance(value, bytes):
|
||||
# Value is bytes, decode to string then parse as ISO format
|
||||
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
if value is not None:
|
||||
return datetime.fromisoformat(value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
|
||||
logger.warning(f"Failed to get cached last_updated_at: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_last_updated_at(updated_at: datetime) -> None:
|
||||
"""Set the cached last_updated_at timestamp in Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
# Store as ISO format string, with 24 hour expiration
|
||||
redis_client.set(
|
||||
_REDIS_KEY_LAST_UPDATED_AT,
|
||||
get_cache_backend().set(
|
||||
_CACHE_KEY_LAST_UPDATED_AT,
|
||||
updated_at.isoformat(),
|
||||
ex=60 * 60 * 24, # 24 hours
|
||||
ex=_CACHE_TTL_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
|
||||
logger.warning(f"Failed to set cached last_updated_at: {e}")
|
||||
|
||||
|
||||
def fetch_llm_recommendations_from_github(
|
||||
@@ -148,9 +142,8 @@ def sync_llm_models_from_github(
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the cache timestamp in Redis. Useful for testing."""
|
||||
"""Reset the cache timestamp. Useful for testing."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset cache in Redis: {e}")
|
||||
logger.warning(f"Failed to reset cache: {e}")
|
||||
|
||||
@@ -32,11 +32,14 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -254,8 +257,53 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
|
||||
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
|
||||
|
||||
The Postgres cache backend eliminates the Redis dependency, but only works
|
||||
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
|
||||
"""
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
|
||||
raise RuntimeError(
|
||||
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
|
||||
"The Postgres cache backend is only supported in no-vector-DB "
|
||||
"deployments where Celery is replaced by the in-process task runner."
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
|
||||
since these modes require infrastructure that is removed in no-vector-DB deployments.
|
||||
"""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return
|
||||
|
||||
if MULTI_TENANT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
|
||||
"Multi-tenant deployments require the vector database for "
|
||||
"per-tenant document indexing and search. Run in single-tenant "
|
||||
"mode when disabling the vector database."
|
||||
)
|
||||
|
||||
from onyx.server.features.build.configs import ENABLE_CRAFT
|
||||
|
||||
if ENABLE_CRAFT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
|
||||
"Onyx Craft requires background workers for sandbox lifecycle "
|
||||
"management, which are removed in no-vector-DB deployments. "
|
||||
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -324,8 +372,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -3,10 +3,12 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -243,6 +245,44 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"check_seat_availability",
|
||||
None,
|
||||
)
|
||||
seat_result = check_seat_fn(db_session=db_session)
|
||||
if seat_result is not None and not seat_result.available:
|
||||
logger.info(
|
||||
f"Blocked inactive Slack user {message_info.email}: "
|
||||
f"{seat_result.error_message}"
|
||||
)
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
text=(
|
||||
"We weren't able to respond because your organization "
|
||||
"has reached its user seat limit. Your account is "
|
||||
"currently deactivated and cannot be reactivated "
|
||||
"until more seats are available. Please contact "
|
||||
"your Onyx administrator."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
activate_user(existing_user, db_session)
|
||||
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"invalidate_license_cache",
|
||||
None,
|
||||
)
|
||||
invalidate_license_cache_fn()
|
||||
logger.info(f"Reactivated inactive Slack user {message_info.email}")
|
||||
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
persona_id: int,
|
||||
is_default_request: IsDefaultRequest,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_is_default(
|
||||
update_persona_featured(
|
||||
persona_id=persona_id,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
featured=is_featured_request.featured,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona default status")
|
||||
logger.exception("Failed to update persona featured status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -108,11 +107,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -128,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
featured: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -155,9 +150,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -175,7 +167,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -214,8 +206,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -230,7 +220,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -252,11 +242,9 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -265,7 +253,6 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -289,11 +276,9 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -324,7 +309,6 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -332,12 +316,10 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on perona's internal settings
|
||||
# Model with full context on persona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -360,7 +342,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -391,10 +373,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -12,13 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -34,7 +29,6 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -55,7 +49,27 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -111,6 +125,7 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -137,12 +152,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -192,6 +207,7 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -208,7 +224,6 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -224,7 +239,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -237,6 +252,7 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -268,7 +284,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -335,7 +351,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -354,13 +370,11 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
)
|
||||
|
||||
|
||||
@@ -426,6 +440,7 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -458,15 +473,25 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -8,10 +8,10 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.cache.factory import get_shared_cache_backend
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
@@ -113,60 +113,46 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
etag = cache.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return etag.decode("utf-8")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
logger.error(f"Failed to get cached etag: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
raw = cache.get(REDIS_KEY_FETCHED_AT)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
logger.error(f"Failed to get last fetch time from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
logger.error(f"Failed to save fetch metadata to cache: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
@@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
cache = get_shared_cache_backend()
|
||||
lock = cache.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
timeout=90,
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -134,11 +133,7 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -147,7 +142,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -97,7 +97,6 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -233,12 +237,9 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -268,7 +269,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -303,7 +304,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -328,7 +329,15 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -344,18 +353,44 @@ def put_llm_provider(
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -393,22 +428,6 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -438,8 +457,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -460,37 +479,48 @@ def put_llm_provider(
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
def delete_llm_provider(
|
||||
provider_id: int,
|
||||
force: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
@admin_router.post("/default")
|
||||
def set_provider_as_default(
|
||||
provider_id: int,
|
||||
default_model_request: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
@admin_router.post("/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
default_model: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
provider_id=default_model.provider_id,
|
||||
vision_model=default_model.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -516,7 +546,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -545,7 +575,13 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
return vision_provider_response
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -555,7 +591,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -577,9 +613,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -592,7 +628,15 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -604,7 +648,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are always included.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -618,7 +662,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -635,11 +679,11 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -668,7 +712,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Use simplified access check - public providers always included
|
||||
# Check access with persona context — respects persona restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -682,7 +726,51 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return llm_provider_list
|
||||
# Get the default model and vision model for the persona
|
||||
# TODO: Port persona's over to use ID
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text = DefaultModel.from_model_config(default_text_model)
|
||||
default_vision = DefaultModel.from_model_config(default_vision_model)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider and can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
# There is no logic that requires sending the providers default vision model name
|
||||
# We only send for the one that is actually the default
|
||||
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
|
||||
"""Find the default conversation model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns empty string.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
|
||||
return model_config.name
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
|
||||
"""Find the default vision model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns None.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
|
||||
return model_config.name
|
||||
return None
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
id: int | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,24 +72,12 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -130,18 +91,17 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -157,8 +117,6 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -180,16 +138,6 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -202,10 +150,6 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -425,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -198,7 +198,6 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
@@ -13,13 +13,13 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import convert_chat_history_basic
|
||||
@@ -67,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
@@ -152,10 +151,20 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -163,11 +172,16 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -181,7 +195,8 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
],
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
|
||||
@@ -314,7 +329,7 @@ def get_chat_session(
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
is_processing = is_chat_session_processing(session_id, get_cache_backend())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
@@ -911,11 +926,10 @@ async def search_chats(
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(current_user), # noqa: ARG001
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
Stop a chat session by setting a stop signal.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -192,6 +192,7 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GRACE_PERIOD = "grace_period"
|
||||
GATED_ACCESS = "gated_access"
|
||||
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -82,6 +83,10 @@ class Settings(BaseModel):
|
||||
# Default Assistant settings
|
||||
disable_default_assistant: bool | None = False
|
||||
|
||||
# Seat usage - populated by license enforcement when seat limit is exceeded
|
||||
seat_count: int | None = None
|
||||
used_seats: int | None = None
|
||||
|
||||
# OpenSearch migration
|
||||
opensearch_indexing_enabled: bool = False
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -6,11 +7,8 @@ from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -33,30 +31,22 @@ def load_settings() -> Settings:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
try:
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
if value is not None:
|
||||
assert isinstance(value, bytes)
|
||||
anonymous_user_enabled = int(value.decode("utf-8")) == 1
|
||||
else:
|
||||
# Default to False
|
||||
anonymous_user_enabled = False
|
||||
# Optionally store the default back to Redis
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
|
||||
)
|
||||
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
|
||||
except Exception as e:
|
||||
# Log the error and reset to default
|
||||
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
|
||||
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
|
||||
# Override user knowledge setting if disabled via environment variable
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
@@ -66,11 +56,10 @@ def load_settings() -> Settings:
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
if settings.anonymous_user_enabled is not None:
|
||||
redis_client.set(
|
||||
cache.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
|
||||
"1" if settings.anonymous_user_enabled else "0",
|
||||
ex=SETTINGS_TTL,
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -254,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
provider_name = "DevEnvPresetOpenAI"
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
name="DevEnvPresetOpenAI",
|
||||
id=existing.id if existing else None,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -273,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -8,37 +8,3 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -528,7 +528,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.3
|
||||
lxml-html-clean==0.4.4
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.3
|
||||
pypdf==6.7.5
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
@@ -74,10 +73,6 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
user=None, # System persona
|
||||
name=f"Test Persona {uuid.uuid4()}",
|
||||
description="Test persona with no tools for web search test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
"""External dependency unit tests for periodic task claiming.
|
||||
|
||||
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
|
||||
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
|
||||
|
||||
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
|
||||
timestamp for cross-instance dedup. The DB session is released before
|
||||
the task runs, so long-running tasks don't hold connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.periodic_poller import _PeriodicTaskDef
|
||||
from onyx.background.periodic_poller import _try_claim_task
|
||||
from onyx.background.periodic_poller import _try_run_periodic_task
|
||||
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import KVStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
_TEST_LOCK_BASE = 90_000
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def _init_engine() -> None:
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
|
||||
def _make_task(
|
||||
*,
|
||||
name: str | None = None,
|
||||
interval: float = 3600,
|
||||
lock_id: int | None = None,
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name if name is not None else f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
|
||||
run_fn=run_fn if run_fn is not None else MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_kv(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KVStore).filter(
|
||||
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_claim_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimHappyPath:
|
||||
def test_first_claim_succeeds(self) -> None:
|
||||
assert _try_claim_task(_make_task()) is True
|
||||
|
||||
def test_first_claim_creates_kv_row(self) -> None:
|
||||
task = _make_task()
|
||||
_try_claim_task(task)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = (
|
||||
db_session.query(KVStore)
|
||||
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
|
||||
.first()
|
||||
)
|
||||
assert row is not None
|
||||
assert row.value is not None
|
||||
|
||||
def test_second_claim_within_interval_fails(self) -> None:
|
||||
task = _make_task(interval=3600)
|
||||
assert _try_claim_task(task) is True
|
||||
assert _try_claim_task(task) is False
|
||||
|
||||
def test_claim_after_interval_succeeds(self) -> None:
|
||||
task = _make_task(interval=1)
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
assert row is not None
|
||||
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||
db_session.commit()
|
||||
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_run_periodic_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunHappyPath:
|
||||
def test_runs_task_and_updates_last_run_at(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_called_once()
|
||||
assert task.last_run_at > 0
|
||||
|
||||
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn, interval=3600)
|
||||
task.last_run_at = time.monotonic()
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_skips_when_db_claim_blocked(self) -> None:
|
||||
name = f"test-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 10
|
||||
|
||||
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
|
||||
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_task_exception_does_not_propagate(self) -> None:
|
||||
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
def test_claim_committed_before_task_runs(self) -> None:
|
||||
"""The KV claim must be visible in the DB when run_fn executes."""
|
||||
task_name = f"test-order-{uuid4().hex[:8]}"
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
|
||||
claim_visible: list[bool] = []
|
||||
|
||||
def check_claim() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
claim_visible.append(row is not None and row.value is not None)
|
||||
|
||||
task = _PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=_TEST_LOCK_BASE + 11,
|
||||
run_fn=check_claim,
|
||||
)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
assert claim_visible == [True]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrency: only one claimer should win
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimConcurrency:
|
||||
def test_concurrent_claims_single_winner(self) -> None:
|
||||
"""Many threads claim the same task — exactly one should succeed."""
|
||||
num_threads = 20
|
||||
task_name = f"test-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 20
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
results: list[bool] = []
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
results.append(future.result())
|
||||
|
||||
winners = sum(1 for r in results if r)
|
||||
assert winners == 1, f"Expected 1 winner, got {winners}"
|
||||
|
||||
def test_concurrent_run_single_execution(self) -> None:
|
||||
"""Many threads run the same task — run_fn fires exactly once."""
|
||||
num_threads = 20
|
||||
task_name = f"test-run-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 21
|
||||
counter = MagicMock()
|
||||
|
||||
def run() -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
_try_run_periodic_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=counter,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(run) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
assert (
|
||||
counter.call_count == 1
|
||||
), f"Expected run_fn called once, got {counter.call_count}"
|
||||
|
||||
def test_no_errors_under_contention(self) -> None:
|
||||
"""All threads complete without exceptions under high contention."""
|
||||
num_threads = 30
|
||||
task_name = f"test-err-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 22
|
||||
errors: list[Exception] = []
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
assert errors == [], f"Got {len(errors)} errors: {errors}"
|
||||
@@ -0,0 +1,352 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``*_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
|
||||
db_session: Session,
|
||||
user_id: object,
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
needs_project_sync: bool = False,
|
||||
needs_persona_sync: bool = False,
|
||||
) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=status,
|
||||
needs_project_sync=needs_project_sync,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _fake_delete_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: delete the row so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _fake_sync_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: clear sync flags so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == UUID(user_file_id))
|
||||
.values(needs_project_sync=False, needs_persona_sync=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
created: list[UserFile] = []
|
||||
yield created
|
||||
for uf in created:
|
||||
existing = db_session.get(UserFile, uf.id)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
|
||||
"""Files in PROCESSING status are re-processed via the processing drain loop."""
|
||||
|
||||
def test_processing_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_proc")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
|
||||
|
||||
def test_completed_files_not_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_comp")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) not in called_ids
|
||||
), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
|
||||
"""Files in DELETING status are recovered via the delete drain loop."""
|
||||
|
||||
def test_deleting_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
# Row is deleted by _fake_delete_impl, so no cleanup needed.
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_delete_impl)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
|
||||
"""Files needing project/persona sync are recovered via the sync drain loop."""
|
||||
|
||||
def test_needs_project_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_sync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
|
||||
|
||||
def test_needs_persona_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_psync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
|
||||
"""Recovery processes all stuck files in one pass, not just the first."""
|
||||
|
||||
def test_multiple_processing_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_multi")
|
||||
files = []
|
||||
for _ in range(3):
|
||||
uf = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
|
||||
expected_ids = {str(uf.id) for uf in files}
|
||||
assert expected_ids.issubset(called_ids), (
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestTransientFailures:
|
||||
"""Drain loops skip failed files, process the rest, and terminate."""
|
||||
|
||||
def test_processing_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_proc")
|
||||
uf_fail = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_delete_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_del")
|
||||
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf_fail)
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_sync_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_sync")
|
||||
uf_fail = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
uf_ok = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
|
||||
|
||||
Verifies that the KV store correctly uses the CacheBackend for caching
|
||||
in front of PostgreSQL: cache hits, cache misses falling through to PG,
|
||||
cache population after PG reads, cache invalidation on delete, and
|
||||
graceful degradation when the cache backend raises.
|
||||
|
||||
Requires running PostgreSQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import CacheStore
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.key_value_store.store import REDIS_KEY_PREFIX
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_kv() -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(KVStore))
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
|
||||
return PgRedisKVStore(cache=pg_cache)
|
||||
|
||||
|
||||
class TestStoreAndLoad:
|
||||
def test_store_populates_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k1", {"hello": "world"})
|
||||
|
||||
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
|
||||
assert cached is not None
|
||||
assert json.loads(cached) == {"hello": "world"}
|
||||
|
||||
loaded = kv_store.load("k1")
|
||||
assert loaded == {"hello": "world"}
|
||||
|
||||
def test_load_returns_cached_value_without_pg_hit(
|
||||
self, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
"""If the cache already has the value, PG should not be queried."""
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
|
||||
kv = PgRedisKVStore(cache=pg_cache)
|
||||
assert kv.load("cached_only") == {"from": "cache"}
|
||||
|
||||
def test_load_falls_through_to_pg_on_cache_miss(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k2", [1, 2, 3])
|
||||
|
||||
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
|
||||
|
||||
loaded = kv_store.load("k2")
|
||||
assert loaded == [1, 2, 3]
|
||||
|
||||
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
|
||||
assert repopulated is not None
|
||||
assert json.loads(repopulated) == [1, 2, 3]
|
||||
|
||||
def test_load_with_refresh_cache_skips_cache(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k3", "original")
|
||||
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
|
||||
|
||||
loaded = kv_store.load("k3", refresh_cache=True)
|
||||
assert loaded == "original"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_from_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("del_me", "bye")
|
||||
kv_store.delete("del_me")
|
||||
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
|
||||
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.load("del_me")
|
||||
|
||||
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.delete("nonexistent")
|
||||
|
||||
|
||||
class TestCacheFailureGracefulDegradation:
|
||||
def test_store_succeeds_when_cache_set_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("resilient", {"data": True})
|
||||
|
||||
working_cache = MagicMock(spec=CacheBackend)
|
||||
working_cache.get.return_value = None
|
||||
kv_reader = PgRedisKVStore(cache=working_cache)
|
||||
loaded = kv_reader.load("resilient")
|
||||
assert loaded == {"data": True}
|
||||
|
||||
def test_load_falls_through_when_cache_get_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.get.side_effect = ConnectionError("cache down")
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("survive", 42)
|
||||
loaded = kv.load("survive")
|
||||
assert loaded == 42
|
||||
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
stmt = select(CacheStore.key).where(CacheStore.key == k)
|
||||
with get_session_with_current_tenant() as session:
|
||||
row = session.execute(stmt).first()
|
||||
assert row is None, "expired row should be physically deleted"
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -36,7 +36,6 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -86,12 +85,6 @@ def _create_test_persona(
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
@@ -410,10 +403,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -442,10 +431,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -461,16 +446,11 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -501,10 +481,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -519,15 +495,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -18,7 +18,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -58,12 +57,6 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -44,15 +45,14 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> None:
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
upsert_llm_provider(
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
upsert_llm_provider(
|
||||
provider = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name,
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=model_name,
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -49,7 +49,6 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert provider.custom_config == custom_config, (
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {provider.custom_config}"
|
||||
f"but got {db_provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
view = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
provider=provider_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -15,9 +15,11 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
|
||||
)
|
||||
|
||||
# Verify initial state: all models are visible
|
||||
provider = fetch_existing_llm_provider(
|
||||
initial_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
assert initial_provider is not None
|
||||
assert initial_provider.is_auto_mode is False
|
||||
|
||||
for mc in provider.model_configurations:
|
||||
for mc in initial_provider.model_configurations:
|
||||
assert (
|
||||
mc.is_visible is True
|
||||
), f"Initial model '{mc.name}' should be visible"
|
||||
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=initial_provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
provider_view = fetch_llm_provider_view(
|
||||
provider_name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
assert provider_view is not None
|
||||
assert provider_view.is_auto_mode is True
|
||||
|
||||
# Build a map of model name -> visibility
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
mc.name: mc.is_visible for mc in provider_view.model_configurations
|
||||
}
|
||||
|
||||
# Models in auto mode config should be visible
|
||||
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Disabling auto mode should preserve the current model list and
|
||||
prevent future syncs from altering visibility.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode — models synced from config.
|
||||
2. Update provider to manual mode (is_auto_mode=False).
|
||||
3. Verify all models remain with unchanged visibility.
|
||||
4. Call sync_auto_mode_models with a *different* config.
|
||||
5. Verify fetch_auto_mode_providers excludes this provider, so the
|
||||
periodic task would never call sync on it.
|
||||
"""
|
||||
initial_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create in auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=initial_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility_before = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
|
||||
|
||||
# Step 2: Switch to manual mode
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=False,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Models unchanged
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
visibility_after = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_after == visibility_before
|
||||
|
||||
# Step 4-5: Provider excluded from auto mode queries
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
auto_provider_ids = {p.id for p in auto_providers}
|
||||
assert provider.id not in auto_provider_ids
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_resync_adds_new_and_hides_removed_models(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the GitHub config changes between syncs, a subsequent sync
|
||||
should add newly listed models and hide models that were removed.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
|
||||
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
|
||||
gpt-4-turbo added).
|
||||
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
|
||||
and visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create with config v1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Re-sync with config v2
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 3: Verify
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# gpt-4o: still in config -> visible
|
||||
assert visibility["gpt-4o"] is True
|
||||
# gpt-4o-mini: removed from config -> hidden (not deleted)
|
||||
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
|
||||
assert visibility["gpt-4o-mini"] is False
|
||||
# gpt-4-turbo: newly added -> visible
|
||||
assert visibility["gpt-4-turbo"] is True
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_is_idempotent(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Running sync twice with the same config should produce zero
|
||||
changes on the second call."""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
# First explicit sync (may report changes if creation already synced)
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Snapshot state after first sync
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
snapshot = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Second sync — should be a no-op
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert (
|
||||
changes == 0
|
||||
), f"Expected 0 changes on idempotent re-sync, got {changes}"
|
||||
|
||||
# State should be identical
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
current = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
assert current == snapshot
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_default_model_hidden_when_removed_from_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the current default model is removed from the config, sync
|
||||
should hide it. The default model flow row should still exist (it
|
||||
points at the ModelConfiguration), but the model is no longer visible.
|
||||
|
||||
Steps:
|
||||
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
|
||||
2. Set gpt-4o as the global default.
|
||||
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
|
||||
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
|
||||
fetch_default_llm_model still returns a result (the flow row persists).
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=[],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Set gpt-4o as global default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Step 3: Re-sync with config v2 (gpt-4o removed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 4: Verify visibility
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -64,7 +64,6 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
|
||||
name=provider_name,
|
||||
provider="openai",
|
||||
api_key="test-api-key",
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
|
||||
@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -55,11 +54,6 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
|
||||
persona = Persona(
|
||||
name=f"test_slack_persona_{unique_id}",
|
||||
description="Test persona for Slack federated search",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question based on the provided context.",
|
||||
)
|
||||
@@ -434,7 +428,6 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -448,7 +441,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
@@ -819,11 +812,6 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
|
||||
persona = Persona(
|
||||
name=f"test_eager_load_persona_{unique_id}",
|
||||
description="Test persona for eager loading test",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question.",
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -47,12 +46,6 @@ def _create_test_persona_with_mcp_tool(
|
||||
persona = Persona(
|
||||
name=f"Test MCP Persona {uuid4().hex[:8]}",
|
||||
description="Test persona with MCP tools",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -17,7 +17,6 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -57,12 +56,6 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -933,6 +933,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from fastapi.background import BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -1139,6 +1140,7 @@ def test_code_interpreter_receives_chat_files(
|
||||
# Upload a test CSV
|
||||
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
|
||||
result = upload_user_files(
|
||||
bg_tasks=BackgroundTasks(),
|
||||
files=[
|
||||
UploadFile(
|
||||
file=io.BytesIO(csv_content),
|
||||
|
||||
@@ -4,10 +4,12 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -32,7 +34,6 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -65,7 +66,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -75,9 +76,19 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -104,7 +115,7 @@ class LLMProviderManager:
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
return [LLMProviderView(**p) for p in response.json()["providers"]]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -113,7 +124,11 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -126,11 +141,30 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and (
|
||||
default_model is None or default_model.model_name in model_names
|
||||
)
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
default_text = response.json().get("default_text")
|
||||
if default_text is None:
|
||||
return None
|
||||
return DefaultModel(**default_text)
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -20,11 +19,7 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -35,6 +30,7 @@ class PersonaManager:
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
display_priority: int | None = None,
|
||||
featured: bool = False,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
@@ -47,11 +43,7 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -61,6 +53,7 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -75,11 +68,7 @@ class PersonaManager:
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
@@ -90,6 +79,7 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -100,11 +90,7 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -113,6 +99,7 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
featured: bool | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -123,13 +110,7 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=persona.is_public if is_public is None else is_public,
|
||||
llm_filter_extraction=(
|
||||
llm_filter_extraction or persona.llm_filter_extraction
|
||||
),
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
@@ -141,6 +122,7 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -155,16 +137,12 @@ class PersonaManager:
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
|
||||
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
@@ -173,7 +151,8 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -222,32 +201,13 @@ class PersonaManager:
|
||||
fetched_persona.description,
|
||||
)
|
||||
)
|
||||
if fetched_persona.num_chunks != persona.num_chunks:
|
||||
mismatches.append(
|
||||
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
|
||||
)
|
||||
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_relevance_filter",
|
||||
persona.llm_relevance_filter,
|
||||
fetched_persona.llm_relevance_filter,
|
||||
)
|
||||
)
|
||||
if fetched_persona.is_public != persona.is_public:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_filter_extraction
|
||||
!= persona.llm_filter_extraction
|
||||
):
|
||||
if fetched_persona.featured != persona.featured:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_filter_extraction",
|
||||
persona.llm_filter_extraction,
|
||||
fetched_persona.llm_filter_extraction,
|
||||
)
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
|
||||
|
||||
class ScimClient:
|
||||
"""HTTP client for making authenticated SCIM v2 requests."""
|
||||
|
||||
@staticmethod
|
||||
def _headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def post(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def put(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.put(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def patch(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.delete(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -51,29 +50,3 @@ class ScimTokenManager:
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import Field
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -128,7 +127,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
default_model_name: str | None = None
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
@@ -162,11 +161,7 @@ class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
@@ -174,6 +169,7 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -8,7 +8,6 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
@@ -36,14 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -414,6 +414,24 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Pause the connector immediately to prevent check_for_indexing from
|
||||
# creating automatic retry attempts while we reset the mock server.
|
||||
# Without this, the INITIAL_INDEXING status causes immediate retries
|
||||
# that would consume (or fail against) the mock server before we can
|
||||
# set up the recovery behavior.
|
||||
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
|
||||
# Collect all index attempt IDs created so far (the initial one plus
|
||||
# any automatic retries that may have started before the pause took effect).
|
||||
all_prior_attempt_ids: list[int] = []
|
||||
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair.id,
|
||||
page=0,
|
||||
page_size=100,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
@@ -465,17 +483,14 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# After the failure, the connector is in repeated error state and paused.
|
||||
# Set the manual indexing trigger first (while paused), then unpause.
|
||||
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
|
||||
# prevent the connector from being re-paused when repeated error state is detected.
|
||||
# Set the manual indexing trigger, then unpause to allow the recovery run.
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
index_attempts_to_ignore=all_prior_attempt_ids,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
|
||||
@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
|
||||
# )
|
||||
break
|
||||
|
||||
if time.monotonic() - start_time > 30:
|
||||
assert False, "CC pair did not enter repeated error state within 30 seconds"
|
||||
if time.monotonic() - start_time > 90:
|
||||
assert False, "CC pair did not enter repeated error state within 90 seconds"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@@ -42,12 +42,10 @@ def _create_provider_with_api(
|
||||
llm_provider_data = {
|
||||
"name": name,
|
||||
"provider": provider_type,
|
||||
"default_model_name": default_model,
|
||||
"api_key": "test-api-key-for-auto-mode-testing",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"custom_config": None,
|
||||
"fast_default_model_name": default_model,
|
||||
"is_public": True,
|
||||
"is_auto_mode": is_auto_mode,
|
||||
"groups": [],
|
||||
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json():
|
||||
for provider in response.json()["providers"]:
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
|
||||
"is_visible"
|
||||
], "Outdated model should not be visible after sync"
|
||||
|
||||
# Verify default model was set from GitHub config
|
||||
expected_default = (
|
||||
default_model["name"] if isinstance(default_model, dict) else default_model
|
||||
)
|
||||
assert synced_provider["default_model_name"] == expected_default, (
|
||||
f"Default model should be {expected_default}, "
|
||||
f"got {synced_provider['default_model_name']}"
|
||||
)
|
||||
|
||||
|
||||
def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
f"Manual mode provider models should not change. "
|
||||
f"Initial: {initial_models}, Current: {current_models}"
|
||||
)
|
||||
|
||||
assert (
|
||||
updated_provider["default_model_name"] == custom_model
|
||||
), f"Manual mode default model should remain {custom_model}"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,22 +4,22 @@ import pytest
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
@@ -41,24 +41,30 @@ def _create_llm_provider(
|
||||
is_public: bool,
|
||||
is_default: bool,
|
||||
) -> LLMProviderModel:
|
||||
provider = LLMProviderModel(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
if is_default:
|
||||
update_default_provider(_provider.id, default_model_name, db_session)
|
||||
|
||||
provider = db_session.get(LLMProviderModel, _provider.id)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {name} not found")
|
||||
return provider
|
||||
|
||||
|
||||
@@ -71,12 +77,6 @@ def _create_persona(
|
||||
persona = Persona(
|
||||
name=name,
|
||||
description=f"{name} description",
|
||||
num_chunks=5,
|
||||
chunks_above=2,
|
||||
chunks_below=2,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=provider_name,
|
||||
llm_model_version_override="gpt-4o-mini",
|
||||
system_prompt="System prompt",
|
||||
@@ -243,6 +243,116 @@ def test_can_user_access_llm_provider_or_logic(
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_with_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers should still enforce persona restrictions.
|
||||
|
||||
Regression test for the bug where is_public=True caused
|
||||
can_user_access_llm_provider() to return True immediately,
|
||||
bypassing persona whitelist checks entirely.
|
||||
"""
|
||||
admin_user, _basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Public provider with persona restrictions
|
||||
public_restricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-persona-restricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
non_whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="non-whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
|
||||
# Only whitelist one persona
|
||||
db_session.add(
|
||||
LLMProvider__Persona(
|
||||
llm_provider_id=public_restricted.id,
|
||||
persona_id=whitelisted_persona.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
db_session.refresh(public_restricted)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
assert admin_model is not None
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
|
||||
# Whitelisted persona — should be allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
whitelisted_persona,
|
||||
)
|
||||
|
||||
# Non-whitelisted persona — should be denied despite is_public=True
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
non_whitelisted_persona,
|
||||
)
|
||||
|
||||
# No persona context (e.g. global provider list) — should be denied
|
||||
# because provider has persona restrictions set
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
persona=None,
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_without_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers with no persona restrictions remain accessible to all."""
|
||||
admin_user, basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
public_unrestricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-unrestricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
any_persona = _create_persona(
|
||||
db_session,
|
||||
name="any-persona",
|
||||
provider_name=public_unrestricted.name,
|
||||
)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
basic_model = db_session.get(User, basic_user.id)
|
||||
assert admin_model is not None
|
||||
assert basic_model is not None
|
||||
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
|
||||
|
||||
# Any user, any persona — all allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, basic_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, persona=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
@@ -270,24 +380,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
provider_name=restricted_provider.name,
|
||||
)
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(default_model_config)
|
||||
db_session.flush()
|
||||
db_session.add(
|
||||
LLMModelFlow(
|
||||
model_configuration_id=default_model_config.id,
|
||||
llm_model_flow_type=LLMModelFlowType.CHAT,
|
||||
is_default=True,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
access_group = UserGroup(name="persona-group")
|
||||
db_session.add(access_group)
|
||||
db_session.flush()
|
||||
@@ -321,13 +413,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
persona=persona,
|
||||
user=admin_model,
|
||||
)
|
||||
assert allowed_llm.config.model_name == restricted_provider.default_model_name
|
||||
assert (
|
||||
allowed_llm.config.model_name
|
||||
== restricted_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
fallback_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=basic_model,
|
||||
)
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
assert (
|
||||
fallback_llm.config.model_name
|
||||
== default_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
@@ -346,6 +444,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -365,7 +464,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -380,7 +479,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
@@ -396,6 +495,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
|
||||
name="default-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider since basic_user can access the persona
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
|
||||
|
||||
# Should succeed (persona_id=0 refers to default persona, which is public)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should NOT include the restricted provider since basic_user is not in group2
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
|
||||
|
||||
# Should succeed - admins can access any persona
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should return the public provider
|
||||
assert len(providers) > 0
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, builtin_persona, is_default_persona, deleted
|
||||
SELECT id, name, builtin_persona, featured, deleted
|
||||
FROM persona
|
||||
WHERE builtin_persona = true
|
||||
ORDER BY id
|
||||
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert default[0] == 0, "Default assistant should have ID 0"
|
||||
assert default[1] == "Assistant", "Should be named 'Assistant'"
|
||||
assert default[2] is True, "Should be builtin"
|
||||
assert default[3] is True, "Should be default"
|
||||
assert default[3] is True, "Should be featured"
|
||||
assert default[4] is False, "Should not be deleted"
|
||||
|
||||
# Check tools are properly associated
|
||||
|
||||
@@ -195,11 +195,7 @@ def _base_persona_body(**overrides: object) -> dict:
|
||||
"description": "test",
|
||||
"system_prompt": "test",
|
||||
"task_prompt": "",
|
||||
"num_chunks": 5,
|
||||
"is_public": True,
|
||||
"recency_bias": "auto",
|
||||
"llm_filter_extraction": False,
|
||||
"llm_relevance_filter": False,
|
||||
"datetime_aware": False,
|
||||
"document_set_ids": [],
|
||||
"tool_ids": [],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user