mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 22:55:46 +00:00
Compare commits
55 Commits
content-re
...
model_depr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21ca9715dd | ||
|
|
3700bee34f | ||
|
|
55117fccf2 | ||
|
|
1de522f9ae | ||
|
|
60fe3e9ad6 | ||
|
|
6aa56821d6 | ||
|
|
eda436de01 | ||
|
|
07915a6c01 | ||
|
|
2c3e9aecd1 | ||
|
|
fa29cc3849 | ||
|
|
24ac8b37d3 | ||
|
|
be8b108ae4 | ||
|
|
f380a75df3 | ||
|
|
21ec93663b | ||
|
|
d789c74024 | ||
|
|
fe014776f7 | ||
|
|
700ca0e0fc | ||
|
|
a84f8238ec | ||
|
|
4fc802e19d | ||
|
|
6cfd49439a | ||
|
|
71a1faa47e | ||
|
|
1a65217baf | ||
|
|
30fa43b5fc | ||
|
|
28332fa24b | ||
|
|
1f5050f9f6 | ||
|
|
3c1d29d3cf | ||
|
|
709e3f4ca7 | ||
|
|
dfa27c08ef | ||
|
|
13d60dcb0e | ||
|
|
30704f427f | ||
|
|
4f3c54f282 | ||
|
|
580d41dc23 | ||
|
|
897e181d67 | ||
|
|
fd322a8a10 | ||
|
|
11c54bafb5 | ||
|
|
c93617df5d | ||
|
|
0cdd438f46 | ||
|
|
31aef36f78 | ||
|
|
0c35dfc0e4 | ||
|
|
a9769757fe | ||
|
|
15d8946f40 | ||
|
|
ba79539d6d | ||
|
|
59d3725fc6 | ||
|
|
9c05bd215d | ||
|
|
4d2aa09654 | ||
|
|
16c07c8756 | ||
|
|
3fb4f5d6e6 | ||
|
|
14fab7fcdf | ||
|
|
22a335fffa | ||
|
|
b0f7466eba | ||
|
|
b1d42726b1 | ||
|
|
7d922bffc1 | ||
|
|
de7fc36fc5 | ||
|
|
7f9e37450d | ||
|
|
c7ef85b733 |
@@ -54,6 +54,7 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
|
||||
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,8 +426,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -499,8 +500,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -646,8 +648,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -728,8 +730,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -862,8 +864,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -934,8 +937,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1072,8 +1076,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1145,8 +1149,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1287,8 +1291,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1366,8 +1371,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -15,6 +15,9 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
@@ -25,16 +28,6 @@ jobs:
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,7 +21,13 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -52,21 +58,14 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -48,28 +48,10 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
@@ -81,6 +63,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -89,6 +72,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
@@ -97,8 +93,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -110,6 +106,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -118,6 +115,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
@@ -126,8 +136,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
@@ -138,6 +148,7 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -146,6 +157,19 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
@@ -154,8 +178,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
@@ -170,56 +194,56 @@ jobs:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
@@ -230,6 +254,7 @@ jobs:
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -238,21 +263,43 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: 4a1e4b1c89d2
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "4a1e4b1c89d2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -0,0 +1,112 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""make scim_user_mapping.external_id nullable
|
||||
|
||||
Revision ID: a3b8d9e2f1c4
|
||||
Revises: 2664261bfaab
|
||||
Create Date: 2026-03-02
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3b8d9e2f1c4"
|
||||
down_revision = "2664261bfaab"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete any rows where external_id is NULL before re-applying NOT NULL
|
||||
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -126,12 +126,16 @@ class ScimDAL(DAL):
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str,
|
||||
external_id: str | None,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
"""Create a SCIM mapping for a user.
|
||||
|
||||
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
|
||||
allows this). The mapping still marks the user as SCIM-managed.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
@@ -270,8 +274,13 @@ class ScimDAL(DAL):
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
|
||||
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
|
||||
# unless they were explicitly linked via SCIM provisioning.
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -321,34 +330,37 @@ class ScimDAL(DAL):
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
"""Sync the SCIM mapping for a user.
|
||||
|
||||
If a mapping already exists, its fields are updated (including
|
||||
setting ``external_id`` to ``None`` when the IdP omits it).
|
||||
If no mapping exists and ``new_external_id`` is provided, a new
|
||||
mapping is created. A mapping is never deleted here — SCIM-managed
|
||||
users must retain their mapping to remain visible in ``GET /Users``.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
elif new_external_id:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -167,6 +168,7 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,7 +15,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -24,6 +26,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
@@ -77,6 +80,22 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -404,21 +423,63 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
|
||||
# Enforce seat limit
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -436,18 +497,21 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
|
||||
@@ -19,7 +19,6 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -82,23 +96,14 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
|
||||
return token
|
||||
|
||||
@@ -153,26 +153,31 @@ class ScimProvider(ABC):
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
) -> ScimName:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
# Derive a reasonable name from the email so that SCIM spec tests
|
||||
# see non-empty givenName / familyName for every user resource.
|
||||
local = user.email.split("@")[0] if user.email else ""
|
||||
return ScimName(givenName=local, familyName="", formatted=local)
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -161,12 +160,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -178,6 +171,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
@@ -725,11 +725,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,34 +414,31 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -449,15 +446,18 @@ def process_single_user_file(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -471,7 +471,6 @@ def process_single_user_file(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -493,9 +492,8 @@ def process_single_user_file(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -504,33 +502,43 @@ def process_single_user_file(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -581,36 +589,38 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -648,7 +658,6 @@ def process_single_user_file_delete(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -656,26 +665,34 @@ def process_single_user_file_delete(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -747,32 +764,30 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -783,11 +798,10 @@ def process_single_user_file_project_sync(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -822,7 +836,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -835,11 +849,22 @@ def process_single_user_file_project_sync(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
307
backend/onyx/background/periodic_poller.py
Normal file
307
backend/onyx/background/periodic_poller.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,3 +1,33 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -9,3 +39,168 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process user file {file_id}")
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to sync user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
51
backend/onyx/cache/factory.py
vendored
Normal file
51
backend/onyx/cache/factory.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
CacheBackendType.POSTGRES: _build_postgres_backend,
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
104
backend/onyx/cache/interface.py
vendored
Normal file
104
backend/onyx/cache/interface.py
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
if not self.acquire():
|
||||
raise RuntimeError("Failed to acquire lock")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds.
|
||||
|
||||
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
|
||||
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
@@ -0,0 +1,323 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
|
||||
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
|
||||
# lists whose names share a prefix (e.g. _q:mylist2:...).
|
||||
_LIST_KEY_RANGE_TERMINATOR = ";"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_LOCK_POLL_INTERVAL = 0.1
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
|
||||
collision when concurrent rpush calls occur within the same nanosecond.
|
||||
"""
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
|
||||
"""Advisory-lock-based distributed lock.
|
||||
|
||||
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
|
||||
to the session's connection; releasing or closing the session frees it.
|
||||
|
||||
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
|
||||
``timeout`` seconds. They are released when ``release()`` is
|
||||
called or when the session is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
|
||||
self._lock_id = lock_id
|
||||
self._timeout = timeout
|
||||
self._tenant_id = tenant_id
|
||||
self._session_cm: AbstractContextManager[Session] | None = None
|
||||
self._session: Session | None = None
|
||||
self._acquired = False
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
|
||||
self._session = self._session_cm.__enter__()
|
||||
try:
|
||||
if not blocking:
|
||||
return self._try_lock()
|
||||
|
||||
effective_timeout = blocking_timeout or self._timeout
|
||||
deadline = (
|
||||
(time.monotonic() + effective_timeout) if effective_timeout else None
|
||||
)
|
||||
while True:
|
||||
if self._try_lock():
|
||||
return True
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
return False
|
||||
time.sleep(_LOCK_POLL_INTERVAL)
|
||||
finally:
|
||||
if not self._acquired:
|
||||
self._close_session()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._acquired or self._session is None:
|
||||
return
|
||||
try:
|
||||
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._close_session()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
def _close_session(self) -> None:
|
||||
if self._session_cm is not None:
|
||||
try:
|
||||
self._session_cm.__exit__(None, None, None)
|
||||
finally:
|
||||
self._session_cm = None
|
||||
self._session = None
|
||||
|
||||
def _try_lock(self) -> bool:
|
||||
assert self._session is not None
|
||||
result = self._session.execute(
|
||||
select(func.pg_try_advisory_lock(self._lock_id))
|
||||
).scalar()
|
||||
if result:
|
||||
self._acquired = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
|
||||
|
||||
Each operation opens and closes its own database session so the backend
|
||||
is safe to share across threads. Tenant isolation is handled by
|
||||
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.value).where(
|
||||
CacheStore.key == key,
|
||||
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
value = session.execute(stmt).scalar_one_or_none()
|
||||
if value is None:
|
||||
return None
|
||||
return bytes(value)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
value_bytes = _to_bytes(value)
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=ex)
|
||||
if ex is not None
|
||||
else None
|
||||
)
|
||||
stmt = (
|
||||
pg_insert(CacheStore)
|
||||
.values(key=key, value=value_bytes, expires_at=expires_at)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[CacheStore.key],
|
||||
set_={"value": value_bytes, "expires_at": expires_at},
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(delete(CacheStore).where(CacheStore.key == key))
|
||||
session.commit()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = (
|
||||
select(CacheStore.key)
|
||||
.where(
|
||||
CacheStore.key == key,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
return session.execute(stmt).first() is not None
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
stmt = (
|
||||
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
expires_at: datetime | None = result[0]
|
||||
if expires_at is None:
|
||||
return TTL_NO_EXPIRY
|
||||
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
|
||||
if remaining <= 0:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
return int(remaining)
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return PostgresCacheLock(
|
||||
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
if timeout <= 0:
|
||||
raise ValueError(
|
||||
"PostgresCacheBackend.blpop requires timeout > 0. "
|
||||
"timeout=0 would block the calling thread indefinitely "
|
||||
"with no way to interrupt short of process termination."
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
for key in keys:
|
||||
lower = f"{_LIST_KEY_PREFIX}{key}:"
|
||||
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
|
||||
stmt = (
|
||||
select(CacheStore)
|
||||
.where(
|
||||
CacheStore.key >= lower,
|
||||
CacheStore.key < upper,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.order_by(CacheStore.key)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
row = session.execute(stmt).scalars().first()
|
||||
if row is not None:
|
||||
value = bytes(row.value) if row.value else b""
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return (key.encode(), value)
|
||||
if time.monotonic() >= deadline:
|
||||
return None
|
||||
time.sleep(_BLPOP_POLL_INTERVAL)
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
|
||||
"""Delete rows whose ``expires_at`` is in the past.
|
||||
|
||||
Called by the periodic poller every 5 minutes.
|
||||
"""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
delete(CacheStore).where(
|
||||
CacheStore.expires_at.is_not(None),
|
||||
CacheStore.expires_at < func.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
|
||||
|
||||
This is a thin pass-through — every method maps 1-to-1 to the underlying
|
||||
Redis command. ``TenantRedis`` key-prefixing is handled by the client
|
||||
itself (provided by ``get_redis_client``).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
val = self._r.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
return str(val).encode()
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self._r.set(key, value, ex=ex)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._r.delete(key)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return bool(self._r.exists(key))
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._r.expire(key, seconds)
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return cast(int, self._r.ttl(key))
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return RedisCacheLock(self._r.lock(name, timeout=timeout))
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self._r.rpush(key, value)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1])
|
||||
@@ -1,57 +1,52 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
"""Generate the cache key for a chat session processing fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
chat_session_id: UUID, cache: CacheBackend, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
"""Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
If the key exists, a message is being processed.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
return cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -530,11 +531,13 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file:"
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -558,12 +561,16 @@ def _create_context_files_message(
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
@@ -11,9 +11,10 @@ from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
@@ -79,7 +80,6 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
cache: CacheBackend | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
redis_client = get_redis_client()
|
||||
cache = get_cache_backend()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
redis_client,
|
||||
cache,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=True,
|
||||
)
|
||||
|
||||
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
if cache is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,65 +1,58 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
|
||||
FENCE_TTL = 10 * 60 # 10 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session stop signal fence.
|
||||
"""Generate the cache key for a chat session stop signal fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
|
||||
"""Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
if not value:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
return not cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
|
||||
"""Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -54,6 +55,12 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
|
||||
@@ -98,6 +98,7 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -112,6 +113,9 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
"""
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -872,7 +865,6 @@ def insert_new_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
.returning(ModelConfiguration.id)
|
||||
@@ -907,7 +899,6 @@ def update_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.where(ModelConfiguration.id == model_configuration_id)
|
||||
.returning(ModelConfiguration)
|
||||
|
||||
@@ -103,7 +103,6 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -2883,9 +2882,6 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
|
||||
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
|
||||
@@ -3265,19 +3261,6 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3304,11 +3287,8 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -4943,7 +4923,9 @@ class ScimUserMapping(Base):
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
external_id: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
@@ -5000,3 +4982,25 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
|
||||
"""Key-value cache table used by ``PostgresCacheBackend``.
|
||||
|
||||
Replaces Redis for simple KV caching, locks, and list operations
|
||||
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
|
||||
|
||||
Intentionally separate from ``KVStore``:
|
||||
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
|
||||
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
|
||||
- Holds ephemeral data (tokens, stop signals, lock state) not
|
||||
persistent application config, so cleanup can be aggressive.
|
||||
"""
|
||||
|
||||
__tablename__ = "cache_store"
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -18,11 +18,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -254,13 +251,15 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
# Curators can edit default personas, but not make them
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -281,7 +280,6 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -295,10 +293,7 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
featured=create_persona_request.featured,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -874,10 +869,6 @@ def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -898,13 +889,11 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
featured: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -1015,12 +1004,6 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1034,10 +1017,8 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1090,12 +1071,6 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1111,9 +1086,7 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
featured=(featured if featured is not None else False),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1158,9 +1131,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_is_default(
|
||||
def update_persona_featured(
|
||||
persona_id: int,
|
||||
is_default: bool,
|
||||
featured: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1168,7 +1141,7 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
persona.featured = featured
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -105,8 +106,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -127,16 +128,27 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,8 +5,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -45,8 +43,6 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -73,17 +69,13 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -563,12 +562,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -525,7 +526,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -546,3 +547,41 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
|
||||
@@ -4,39 +4,33 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key prefix for OAuth state
|
||||
OAUTH_STATE_PREFIX = "federated_oauth"
|
||||
# Default TTL for OAuth state (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
OAUTH_STATE_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
class OAuthSession:
|
||||
"""Represents an OAuth session stored in Redis."""
|
||||
"""Represents an OAuth session stored in the cache backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
):
|
||||
self.federated_connector_id = federated_connector_id
|
||||
self.user_id = user_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.additional_data = additional_data or {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for Redis storage."""
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"federated_connector_id": self.federated_connector_id,
|
||||
"user_id": self.user_id,
|
||||
@@ -45,8 +39,7 @@ class OAuthSession:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
|
||||
"""Create from dictionary retrieved from Redis."""
|
||||
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
|
||||
return cls(
|
||||
federated_connector_id=data["federated_connector_id"],
|
||||
user_id=data["user_id"],
|
||||
@@ -58,31 +51,27 @@ class OAuthSession:
|
||||
def generate_oauth_state(
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
ttl: int = OAUTH_STATE_TTL,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a secure state parameter and store session data in Redis.
|
||||
Generate a secure state parameter and store session data in the cache backend.
|
||||
|
||||
Args:
|
||||
federated_connector_id: ID of the federated connector
|
||||
user_id: ID of the user initiating OAuth
|
||||
redirect_uri: Optional redirect URI after OAuth completion
|
||||
additional_data: Any additional data to store with the session
|
||||
ttl: Time-to-live in seconds for the Redis key
|
||||
ttl: Time-to-live in seconds for the cache key
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter
|
||||
"""
|
||||
# Generate a random UUID for the state
|
||||
state_uuid = uuid.uuid4()
|
||||
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Convert UUID to base64 for URL-safe state parameter
|
||||
state_bytes = state_uuid.bytes
|
||||
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Create session object
|
||||
session = OAuthSession(
|
||||
federated_connector_id=federated_connector_id,
|
||||
user_id=user_id,
|
||||
@@ -90,15 +79,9 @@ def generate_oauth_state(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
|
||||
# Store in Redis with TTL
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
json.dumps(session.to_dict()),
|
||||
ex=ttl,
|
||||
)
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
|
||||
@@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> OAuthSession:
|
||||
state_bytes = base64.urlsafe_b64decode(padded_state)
|
||||
state_uuid = uuid.UUID(bytes=state_bytes)
|
||||
|
||||
# Look up in Redis
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
session_data = cast(bytes, redis_client.get(redis_key))
|
||||
session_data = cache.get(cache_key)
|
||||
if not session_data:
|
||||
raise ValueError(f"OAuth state not found in Redis: {state}")
|
||||
raise ValueError(f"OAuth state not found: {state}")
|
||||
|
||||
# Delete the key after retrieval (one-time use)
|
||||
redis_client.delete(redis_key)
|
||||
cache.delete(cache_key)
|
||||
|
||||
# Parse and return session
|
||||
session_dict = json.loads(session_data)
|
||||
return OAuthSession.from_dict(session_dict)
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -20,22 +18,27 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client()
|
||||
def __init__(self, cache: CacheBackend | None = None) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _get_cache(self) -> CacheBackend:
|
||||
if self._cache is None:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
self._cache = get_cache_backend()
|
||||
return self._cache
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback gracefully to Postgres if Redis fails
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
# Fallback gracefully to Postgres if Cache backend fails
|
||||
logger.error(
|
||||
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
@@ -53,16 +56,12 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
|
||||
if not refresh_cache:
|
||||
try:
|
||||
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
|
||||
if redis_value:
|
||||
if not isinstance(redis_value, bytes):
|
||||
raise ValueError(
|
||||
f"Redis value for key '{key}' is not a bytes object"
|
||||
)
|
||||
return json.loads(redis_value.decode("utf-8"))
|
||||
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
|
||||
if cached is not None:
|
||||
return json.loads(cached.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get value from Redis for key '{key}': {str(e)}"
|
||||
f"Failed to get value from cache for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -79,21 +78,21 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
|
||||
|
||||
return cast(JSON_ro, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
self.redis_client.delete(REDIS_KEY_PREFIX + key)
|
||||
self._get_cache().delete(REDIS_KEY_PREFIX + key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
|
||||
@@ -67,6 +67,18 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -725,6 +737,44 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -736,6 +786,7 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -743,3 +794,4 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -13,44 +13,38 @@ from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key for caching the last updated timestamp (per-tenant)
|
||||
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
|
||||
|
||||
|
||||
def _get_cached_last_updated_at() -> datetime | None:
|
||||
"""Get the cached last_updated_at timestamp from Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
if value and isinstance(value, bytes):
|
||||
# Value is bytes, decode to string then parse as ISO format
|
||||
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
if value is not None:
|
||||
return datetime.fromisoformat(value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
|
||||
logger.warning(f"Failed to get cached last_updated_at: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_last_updated_at(updated_at: datetime) -> None:
|
||||
"""Set the cached last_updated_at timestamp in Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
# Store as ISO format string, with 24 hour expiration
|
||||
redis_client.set(
|
||||
_REDIS_KEY_LAST_UPDATED_AT,
|
||||
get_cache_backend().set(
|
||||
_CACHE_KEY_LAST_UPDATED_AT,
|
||||
updated_at.isoformat(),
|
||||
ex=60 * 60 * 24, # 24 hours
|
||||
ex=_CACHE_TTL_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
|
||||
logger.warning(f"Failed to set cached last_updated_at: {e}")
|
||||
|
||||
|
||||
def fetch_llm_recommendations_from_github(
|
||||
@@ -148,9 +142,8 @@ def sync_llm_models_from_github(
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the cache timestamp in Redis. Useful for testing."""
|
||||
"""Reset the cache timestamp. Useful for testing."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset cache in Redis: {e}")
|
||||
logger.warning(f"Failed to reset cache: {e}")
|
||||
|
||||
@@ -32,11 +32,14 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -254,8 +257,53 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
|
||||
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
|
||||
|
||||
The Postgres cache backend eliminates the Redis dependency, but only works
|
||||
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
|
||||
"""
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
|
||||
raise RuntimeError(
|
||||
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
|
||||
"The Postgres cache backend is only supported in no-vector-DB "
|
||||
"deployments where Celery is replaced by the in-process task runner."
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
|
||||
since these modes require infrastructure that is removed in no-vector-DB deployments.
|
||||
"""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return
|
||||
|
||||
if MULTI_TENANT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
|
||||
"Multi-tenant deployments require the vector database for "
|
||||
"per-tenant document indexing and search. Run in single-tenant "
|
||||
"mode when disabling the vector database."
|
||||
)
|
||||
|
||||
from onyx.server.features.build.configs import ENABLE_CRAFT
|
||||
|
||||
if ENABLE_CRAFT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
|
||||
"Onyx Craft requires background workers for sandbox lifecycle "
|
||||
"management, which are removed in no-vector-DB deployments. "
|
||||
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -324,8 +372,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -3,10 +3,12 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -243,6 +245,44 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"check_seat_availability",
|
||||
None,
|
||||
)
|
||||
seat_result = check_seat_fn(db_session=db_session)
|
||||
if seat_result is not None and not seat_result.available:
|
||||
logger.info(
|
||||
f"Blocked inactive Slack user {message_info.email}: "
|
||||
f"{seat_result.error_message}"
|
||||
)
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
text=(
|
||||
"We weren't able to respond because your organization "
|
||||
"has reached its user seat limit. Your account is "
|
||||
"currently deactivated and cannot be reactivated "
|
||||
"until more seats are available. Please contact "
|
||||
"your Onyx administrator."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
activate_user(existing_user, db_session)
|
||||
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"invalidate_license_cache",
|
||||
None,
|
||||
)
|
||||
invalidate_license_cache_fn()
|
||||
logger.info(f"Reactivated inactive Slack user {message_info.email}")
|
||||
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
|
||||
@@ -1133,7 +1133,8 @@ done
|
||||
# Already deleted
|
||||
service_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Service {service_name}: {e}")
|
||||
logger.error(f"Error deleting Service {service_name}: {e}")
|
||||
raise
|
||||
|
||||
pod_deleted = False
|
||||
try:
|
||||
@@ -1148,7 +1149,8 @@ done
|
||||
# Already deleted
|
||||
pod_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Pod {pod_name}: {e}")
|
||||
logger.error(f"Error deleting Pod {pod_name}: {e}")
|
||||
raise
|
||||
|
||||
# Wait for resources to be fully deleted to prevent 409 conflicts
|
||||
# on immediate re-provisioning
|
||||
|
||||
@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Prevent overlapping runs of this task
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
persona_id: int,
|
||||
is_default_request: IsDefaultRequest,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_is_default(
|
||||
update_persona_featured(
|
||||
persona_id=persona_id,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
featured=is_featured_request.featured,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona default status")
|
||||
logger.exception("Failed to update persona featured status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -108,11 +107,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -128,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
featured: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -155,9 +150,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -175,7 +167,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -214,8 +206,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -230,7 +220,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -252,11 +242,9 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -265,7 +253,6 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -289,11 +276,9 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -324,7 +309,6 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -332,12 +316,10 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on perona's internal settings
|
||||
# Model with full context on persona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -360,7 +342,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -391,10 +373,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -12,13 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -34,7 +29,6 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -55,7 +49,27 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -111,6 +125,7 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -137,12 +152,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -192,6 +207,7 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -208,7 +224,6 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -224,7 +239,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -237,6 +252,7 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -268,7 +284,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -335,7 +351,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -354,13 +370,11 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
)
|
||||
|
||||
|
||||
@@ -426,6 +440,7 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -458,15 +473,25 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -8,10 +8,10 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.cache.factory import get_shared_cache_backend
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
@@ -113,60 +113,46 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
etag = cache.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return etag.decode("utf-8")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
logger.error(f"Failed to get cached etag: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
raw = cache.get(REDIS_KEY_FETCHED_AT)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
logger.error(f"Failed to get last fetch time from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
logger.error(f"Failed to save fetch metadata to cache: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
@@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
cache = get_shared_cache_backend()
|
||||
lock = cache.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
timeout=90,
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -134,11 +133,7 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -147,7 +142,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -479,10 +479,20 @@ def put_llm_provider(
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
def delete_llm_provider(
|
||||
provider_id: int,
|
||||
force: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -603,9 +613,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -638,7 +648,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are always included.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -652,7 +662,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -673,7 +683,7 @@ def list_llm_providers_for_persona(
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -702,7 +712,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Use simplified access check - public providers always included
|
||||
# Check access with persona context — respects persona restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
|
||||
@@ -168,18 +168,6 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
supports_image_input: bool | None = None
|
||||
display_name: str | None = None # For dynamic providers, from source API
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model_configuration_model: "ModelConfigurationModel"
|
||||
) -> "ModelConfigurationUpsertRequest":
|
||||
return cls(
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
display_name=model_configuration_model.display_name,
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigurationView(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -198,7 +198,6 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
@@ -13,13 +13,13 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import convert_chat_history_basic
|
||||
@@ -67,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
@@ -152,10 +151,20 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -163,11 +172,16 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -181,7 +195,8 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
],
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
|
||||
@@ -314,7 +329,7 @@ def get_chat_session(
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
is_processing = is_chat_session_processing(session_id, get_cache_backend())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
@@ -911,11 +926,10 @@ async def search_chats(
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(current_user), # noqa: ARG001
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
Stop a chat session by setting a stop signal.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -192,6 +192,7 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -6,11 +7,8 @@ from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -33,30 +31,22 @@ def load_settings() -> Settings:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
try:
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
if value is not None:
|
||||
assert isinstance(value, bytes)
|
||||
anonymous_user_enabled = int(value.decode("utf-8")) == 1
|
||||
else:
|
||||
# Default to False
|
||||
anonymous_user_enabled = False
|
||||
# Optionally store the default back to Redis
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
|
||||
)
|
||||
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
|
||||
except Exception as e:
|
||||
# Log the error and reset to default
|
||||
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
|
||||
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
|
||||
# Override user knowledge setting if disabled via environment variable
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
@@ -66,11 +56,10 @@ def load_settings() -> Settings:
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
if settings.anonymous_user_enabled is not None:
|
||||
redis_client.set(
|
||||
cache.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
|
||||
"1" if settings.anonymous_user_enabled else "0",
|
||||
ex=SETTINGS_TTL,
|
||||
|
||||
@@ -8,37 +8,3 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -528,7 +528,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.3
|
||||
lxml-html-clean==0.4.4
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.3
|
||||
pypdf==6.7.5
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
@@ -74,10 +73,6 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
user=None, # System persona
|
||||
name=f"Test Persona {uuid.uuid4()}",
|
||||
description="Test persona with no tools for web search test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
"""External dependency unit tests for periodic task claiming.
|
||||
|
||||
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
|
||||
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
|
||||
|
||||
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
|
||||
timestamp for cross-instance dedup. The DB session is released before
|
||||
the task runs, so long-running tasks don't hold connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.periodic_poller import _PeriodicTaskDef
|
||||
from onyx.background.periodic_poller import _try_claim_task
|
||||
from onyx.background.periodic_poller import _try_run_periodic_task
|
||||
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import KVStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
_TEST_LOCK_BASE = 90_000
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def _init_engine() -> None:
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
|
||||
def _make_task(
|
||||
*,
|
||||
name: str | None = None,
|
||||
interval: float = 3600,
|
||||
lock_id: int | None = None,
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name if name is not None else f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
|
||||
run_fn=run_fn if run_fn is not None else MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_kv(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KVStore).filter(
|
||||
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_claim_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimHappyPath:
|
||||
def test_first_claim_succeeds(self) -> None:
|
||||
assert _try_claim_task(_make_task()) is True
|
||||
|
||||
def test_first_claim_creates_kv_row(self) -> None:
|
||||
task = _make_task()
|
||||
_try_claim_task(task)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = (
|
||||
db_session.query(KVStore)
|
||||
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
|
||||
.first()
|
||||
)
|
||||
assert row is not None
|
||||
assert row.value is not None
|
||||
|
||||
def test_second_claim_within_interval_fails(self) -> None:
|
||||
task = _make_task(interval=3600)
|
||||
assert _try_claim_task(task) is True
|
||||
assert _try_claim_task(task) is False
|
||||
|
||||
def test_claim_after_interval_succeeds(self) -> None:
|
||||
task = _make_task(interval=1)
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
assert row is not None
|
||||
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||
db_session.commit()
|
||||
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_run_periodic_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunHappyPath:
|
||||
def test_runs_task_and_updates_last_run_at(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_called_once()
|
||||
assert task.last_run_at > 0
|
||||
|
||||
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn, interval=3600)
|
||||
task.last_run_at = time.monotonic()
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_skips_when_db_claim_blocked(self) -> None:
|
||||
name = f"test-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 10
|
||||
|
||||
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
|
||||
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_task_exception_does_not_propagate(self) -> None:
|
||||
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
def test_claim_committed_before_task_runs(self) -> None:
|
||||
"""The KV claim must be visible in the DB when run_fn executes."""
|
||||
task_name = f"test-order-{uuid4().hex[:8]}"
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
|
||||
claim_visible: list[bool] = []
|
||||
|
||||
def check_claim() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
claim_visible.append(row is not None and row.value is not None)
|
||||
|
||||
task = _PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=_TEST_LOCK_BASE + 11,
|
||||
run_fn=check_claim,
|
||||
)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
assert claim_visible == [True]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrency: only one claimer should win
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimConcurrency:
|
||||
def test_concurrent_claims_single_winner(self) -> None:
|
||||
"""Many threads claim the same task — exactly one should succeed."""
|
||||
num_threads = 20
|
||||
task_name = f"test-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 20
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
results: list[bool] = []
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
results.append(future.result())
|
||||
|
||||
winners = sum(1 for r in results if r)
|
||||
assert winners == 1, f"Expected 1 winner, got {winners}"
|
||||
|
||||
def test_concurrent_run_single_execution(self) -> None:
|
||||
"""Many threads run the same task — run_fn fires exactly once."""
|
||||
num_threads = 20
|
||||
task_name = f"test-run-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 21
|
||||
counter = MagicMock()
|
||||
|
||||
def run() -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
_try_run_periodic_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=counter,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(run) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
assert (
|
||||
counter.call_count == 1
|
||||
), f"Expected run_fn called once, got {counter.call_count}"
|
||||
|
||||
def test_no_errors_under_contention(self) -> None:
|
||||
"""All threads complete without exceptions under high contention."""
|
||||
num_threads = 30
|
||||
task_name = f"test-err-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 22
|
||||
errors: list[Exception] = []
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
assert errors == [], f"Got {len(errors)} errors: {errors}"
|
||||
@@ -0,0 +1,352 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``*_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
|
||||
db_session: Session,
|
||||
user_id: object,
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
needs_project_sync: bool = False,
|
||||
needs_persona_sync: bool = False,
|
||||
) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=status,
|
||||
needs_project_sync=needs_project_sync,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _fake_delete_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: delete the row so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _fake_sync_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: clear sync flags so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == UUID(user_file_id))
|
||||
.values(needs_project_sync=False, needs_persona_sync=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
created: list[UserFile] = []
|
||||
yield created
|
||||
for uf in created:
|
||||
existing = db_session.get(UserFile, uf.id)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
|
||||
"""Files in PROCESSING status are re-processed via the processing drain loop."""
|
||||
|
||||
def test_processing_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_proc")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
|
||||
|
||||
def test_completed_files_not_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_comp")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) not in called_ids
|
||||
), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
|
||||
"""Files in DELETING status are recovered via the delete drain loop."""
|
||||
|
||||
def test_deleting_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
# Row is deleted by _fake_delete_impl, so no cleanup needed.
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_delete_impl)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
|
||||
"""Files needing project/persona sync are recovered via the sync drain loop."""
|
||||
|
||||
def test_needs_project_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_sync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
|
||||
|
||||
def test_needs_persona_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_psync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
|
||||
"""Recovery processes all stuck files in one pass, not just the first."""
|
||||
|
||||
def test_multiple_processing_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_multi")
|
||||
files = []
|
||||
for _ in range(3):
|
||||
uf = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
|
||||
expected_ids = {str(uf.id) for uf in files}
|
||||
assert expected_ids.issubset(called_ids), (
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestTransientFailures:
|
||||
"""Drain loops skip failed files, process the rest, and terminate."""
|
||||
|
||||
def test_processing_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_proc")
|
||||
uf_fail = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_delete_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_del")
|
||||
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf_fail)
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_sync_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_sync")
|
||||
uf_fail = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
uf_ok = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
|
||||
|
||||
Verifies that the KV store correctly uses the CacheBackend for caching
|
||||
in front of PostgreSQL: cache hits, cache misses falling through to PG,
|
||||
cache population after PG reads, cache invalidation on delete, and
|
||||
graceful degradation when the cache backend raises.
|
||||
|
||||
Requires running PostgreSQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import CacheStore
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.key_value_store.store import REDIS_KEY_PREFIX
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_kv() -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(KVStore))
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
|
||||
return PgRedisKVStore(cache=pg_cache)
|
||||
|
||||
|
||||
class TestStoreAndLoad:
|
||||
def test_store_populates_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k1", {"hello": "world"})
|
||||
|
||||
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
|
||||
assert cached is not None
|
||||
assert json.loads(cached) == {"hello": "world"}
|
||||
|
||||
loaded = kv_store.load("k1")
|
||||
assert loaded == {"hello": "world"}
|
||||
|
||||
def test_load_returns_cached_value_without_pg_hit(
|
||||
self, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
"""If the cache already has the value, PG should not be queried."""
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
|
||||
kv = PgRedisKVStore(cache=pg_cache)
|
||||
assert kv.load("cached_only") == {"from": "cache"}
|
||||
|
||||
def test_load_falls_through_to_pg_on_cache_miss(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k2", [1, 2, 3])
|
||||
|
||||
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
|
||||
|
||||
loaded = kv_store.load("k2")
|
||||
assert loaded == [1, 2, 3]
|
||||
|
||||
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
|
||||
assert repopulated is not None
|
||||
assert json.loads(repopulated) == [1, 2, 3]
|
||||
|
||||
def test_load_with_refresh_cache_skips_cache(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k3", "original")
|
||||
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
|
||||
|
||||
loaded = kv_store.load("k3", refresh_cache=True)
|
||||
assert loaded == "original"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_from_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("del_me", "bye")
|
||||
kv_store.delete("del_me")
|
||||
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
|
||||
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.load("del_me")
|
||||
|
||||
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.delete("nonexistent")
|
||||
|
||||
|
||||
class TestCacheFailureGracefulDegradation:
|
||||
def test_store_succeeds_when_cache_set_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("resilient", {"data": True})
|
||||
|
||||
working_cache = MagicMock(spec=CacheBackend)
|
||||
working_cache.get.return_value = None
|
||||
kv_reader = PgRedisKVStore(cache=working_cache)
|
||||
loaded = kv_reader.load("resilient")
|
||||
assert loaded == {"data": True}
|
||||
|
||||
def test_load_falls_through_when_cache_get_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.get.side_effect = ConnectionError("cache down")
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("survive", 42)
|
||||
loaded = kv.load("survive")
|
||||
assert loaded == 42
|
||||
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
stmt = select(CacheStore.key).where(CacheStore.key == k)
|
||||
with get_session_with_current_tenant() as session:
|
||||
row = session.execute(stmt).first()
|
||||
assert row is None, "expired row should be physically deleted"
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -36,7 +36,6 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -86,12 +85,6 @@ def _create_test_persona(
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
@@ -410,10 +403,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -442,10 +431,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -461,16 +446,11 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -501,10 +481,6 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -519,15 +495,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -18,7 +18,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -58,12 +57,6 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
|
||||
@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -55,11 +54,6 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
|
||||
persona = Persona(
|
||||
name=f"test_slack_persona_{unique_id}",
|
||||
description="Test persona for Slack federated search",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question based on the provided context.",
|
||||
)
|
||||
@@ -818,11 +812,6 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
|
||||
persona = Persona(
|
||||
name=f"test_eager_load_persona_{unique_id}",
|
||||
description="Test persona for eager loading test",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question.",
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -47,12 +46,6 @@ def _create_test_persona_with_mcp_tool(
|
||||
persona = Persona(
|
||||
name=f"Test MCP Persona {uuid4().hex[:8]}",
|
||||
description="Test persona with MCP tools",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -17,7 +17,6 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -57,12 +56,6 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -933,6 +933,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from fastapi.background import BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -1139,6 +1140,7 @@ def test_code_interpreter_receives_chat_files(
|
||||
# Upload a test CSV
|
||||
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
|
||||
result = upload_user_files(
|
||||
bg_tasks=BackgroundTasks(),
|
||||
files=[
|
||||
UploadFile(
|
||||
file=io.BytesIO(csv_content),
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -20,11 +19,7 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -35,6 +30,7 @@ class PersonaManager:
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
display_priority: int | None = None,
|
||||
featured: bool = False,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
@@ -47,11 +43,7 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -61,6 +53,7 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -75,11 +68,7 @@ class PersonaManager:
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
@@ -90,6 +79,7 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -100,11 +90,7 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -113,6 +99,7 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
featured: bool | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -123,13 +110,7 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=persona.is_public if is_public is None else is_public,
|
||||
llm_filter_extraction=(
|
||||
llm_filter_extraction or persona.llm_filter_extraction
|
||||
),
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
@@ -141,6 +122,7 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -155,16 +137,12 @@ class PersonaManager:
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
|
||||
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
@@ -173,7 +151,8 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -222,32 +201,13 @@ class PersonaManager:
|
||||
fetched_persona.description,
|
||||
)
|
||||
)
|
||||
if fetched_persona.num_chunks != persona.num_chunks:
|
||||
mismatches.append(
|
||||
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
|
||||
)
|
||||
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_relevance_filter",
|
||||
persona.llm_relevance_filter,
|
||||
fetched_persona.llm_relevance_filter,
|
||||
)
|
||||
)
|
||||
if fetched_persona.is_public != persona.is_public:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_filter_extraction
|
||||
!= persona.llm_filter_extraction
|
||||
):
|
||||
if fetched_persona.featured != persona.featured:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_filter_extraction",
|
||||
persona.llm_filter_extraction,
|
||||
fetched_persona.llm_filter_extraction,
|
||||
)
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
|
||||
|
||||
class ScimClient:
|
||||
"""HTTP client for making authenticated SCIM v2 requests."""
|
||||
|
||||
@staticmethod
|
||||
def _headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def post(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def put(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.put(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def patch(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.delete(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -51,29 +50,3 @@ class ScimTokenManager:
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import Field
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -162,11 +161,7 @@ class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
@@ -174,6 +169,7 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -8,7 +8,6 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
@@ -36,14 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -414,6 +414,24 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Pause the connector immediately to prevent check_for_indexing from
|
||||
# creating automatic retry attempts while we reset the mock server.
|
||||
# Without this, the INITIAL_INDEXING status causes immediate retries
|
||||
# that would consume (or fail against) the mock server before we can
|
||||
# set up the recovery behavior.
|
||||
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
|
||||
# Collect all index attempt IDs created so far (the initial one plus
|
||||
# any automatic retries that may have started before the pause took effect).
|
||||
all_prior_attempt_ids: list[int] = []
|
||||
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair.id,
|
||||
page=0,
|
||||
page_size=100,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
@@ -465,17 +483,14 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# After the failure, the connector is in repeated error state and paused.
|
||||
# Set the manual indexing trigger first (while paused), then unpause.
|
||||
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
|
||||
# prevent the connector from being re-paused when repeated error state is detected.
|
||||
# Set the manual indexing trigger, then unpause to allow the recovery run.
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
index_attempts_to_ignore=all_prior_attempt_ids,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
|
||||
@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
|
||||
# )
|
||||
break
|
||||
|
||||
if time.monotonic() - start_time > 30:
|
||||
assert False, "CC pair did not enter repeated error state within 30 seconds"
|
||||
if time.monotonic() - start_time > 90:
|
||||
assert False, "CC pair did not enter repeated error state within 90 seconds"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@@ -386,6 +386,261 @@ def test_delete_llm_provider(
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Deleting the default LLM provider should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-default-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete the default provider — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
|
||||
|
||||
def test_delete_non_default_llm_provider_with_default_set(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a non-default provider should succeed even when a default is set."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create two providers
|
||||
response_default = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "default-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_default.status_code == 200
|
||||
default_provider = response_default.json()
|
||||
|
||||
response_other = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "other-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_other.status_code == 200
|
||||
other_provider = response_other.json()
|
||||
|
||||
# Set the first provider as default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": default_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Delete the non-default provider — should succeed
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the non-default provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, other_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify the default provider still exists
|
||||
default_data = _get_provider_by_id(admin_user, default_provider["id"])
|
||||
assert default_data is not None
|
||||
|
||||
|
||||
def test_force_delete_default_llm_provider(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Force-deleting the default LLM provider should succeed."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-force-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete without force — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
|
||||
# Force delete — should succeed
|
||||
force_delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}?force=true",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert force_delete_response.status_code == 200
|
||||
|
||||
# Verify provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_vision_provider_clears_vision_default(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting the default vision provider should succeed and clear the vision default."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a text provider and set it as default (so we have a default text provider)
|
||||
text_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "text-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert text_response.status_code == 200
|
||||
text_provider = text_response.json()
|
||||
_set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini")
|
||||
|
||||
# Create a vision provider and set it as default vision
|
||||
vision_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "vision-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000002",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
supports_image_input=True,
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert vision_response.status_code == 200
|
||||
vision_provider = vision_response.json()
|
||||
_set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o")
|
||||
|
||||
# Verify vision default is set
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, _, vision_default = _unpack_data(data)
|
||||
assert vision_default is not None
|
||||
assert vision_default["provider_id"] == vision_provider["id"]
|
||||
|
||||
# Delete the vision provider — should succeed (only text default is protected)
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the vision provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, vision_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify there is no default vision provider
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, text_default, vision_default = _unpack_data(data)
|
||||
assert vision_default is None
|
||||
|
||||
# Verify the text default is still intact
|
||||
assert text_default is not None
|
||||
assert text_default["provider_id"] == text_provider["id"]
|
||||
|
||||
|
||||
def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Creating a provider with a name that already exists should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -4,7 +4,6 @@ import pytest
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
@@ -78,12 +77,6 @@ def _create_persona(
|
||||
persona = Persona(
|
||||
name=name,
|
||||
description=f"{name} description",
|
||||
num_chunks=5,
|
||||
chunks_above=2,
|
||||
chunks_below=2,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=provider_name,
|
||||
llm_model_version_override="gpt-4o-mini",
|
||||
system_prompt="System prompt",
|
||||
@@ -250,6 +243,116 @@ def test_can_user_access_llm_provider_or_logic(
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_with_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers should still enforce persona restrictions.
|
||||
|
||||
Regression test for the bug where is_public=True caused
|
||||
can_user_access_llm_provider() to return True immediately,
|
||||
bypassing persona whitelist checks entirely.
|
||||
"""
|
||||
admin_user, _basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Public provider with persona restrictions
|
||||
public_restricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-persona-restricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
non_whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="non-whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
|
||||
# Only whitelist one persona
|
||||
db_session.add(
|
||||
LLMProvider__Persona(
|
||||
llm_provider_id=public_restricted.id,
|
||||
persona_id=whitelisted_persona.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
db_session.refresh(public_restricted)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
assert admin_model is not None
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
|
||||
# Whitelisted persona — should be allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
whitelisted_persona,
|
||||
)
|
||||
|
||||
# Non-whitelisted persona — should be denied despite is_public=True
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
non_whitelisted_persona,
|
||||
)
|
||||
|
||||
# No persona context (e.g. global provider list) — should be denied
|
||||
# because provider has persona restrictions set
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
persona=None,
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_without_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers with no persona restrictions remain accessible to all."""
|
||||
admin_user, basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
public_unrestricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-unrestricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
any_persona = _create_persona(
|
||||
db_session,
|
||||
name="any-persona",
|
||||
provider_name=public_unrestricted.name,
|
||||
)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
basic_model = db_session.get(User, basic_user.id)
|
||||
assert admin_model is not None
|
||||
assert basic_model is not None
|
||||
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
|
||||
|
||||
# Any user, any persona — all allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, basic_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, persona=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, builtin_persona, is_default_persona, deleted
|
||||
SELECT id, name, builtin_persona, featured, deleted
|
||||
FROM persona
|
||||
WHERE builtin_persona = true
|
||||
ORDER BY id
|
||||
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert default[0] == 0, "Default assistant should have ID 0"
|
||||
assert default[1] == "Assistant", "Should be named 'Assistant'"
|
||||
assert default[2] is True, "Should be builtin"
|
||||
assert default[3] is True, "Should be default"
|
||||
assert default[3] is True, "Should be featured"
|
||||
assert default[4] is False, "Should not be deleted"
|
||||
|
||||
# Check tools are properly associated
|
||||
|
||||
@@ -195,11 +195,7 @@ def _base_persona_body(**overrides: object) -> dict:
|
||||
"description": "test",
|
||||
"system_prompt": "test",
|
||||
"task_prompt": "",
|
||||
"num_chunks": 5,
|
||||
"is_public": True,
|
||||
"recency_bias": "auto",
|
||||
"llm_filter_extraction": False,
|
||||
"llm_relevance_filter": False,
|
||||
"datetime_aware": False,
|
||||
"document_set_ids": [],
|
||||
"tool_ids": [],
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
|
||||
|
||||
Covers: upload → COMPLETED → unlink from project → delete → gone.
|
||||
|
||||
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
|
||||
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
|
||||
when the server is running with vector DB enabled.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
POLL_INTERVAL_SECONDS = 1
|
||||
POLL_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
def _poll_file_status(
|
||||
file_id: UUID,
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus,
|
||||
timeout: int = POLL_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
"""Poll GET /user/projects/file/{file_id} until the file reaches *target_status*."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.ok:
|
||||
status = resp.json().get("status")
|
||||
if status == target_status.value:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} did not reach {target_status.value} within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
|
||||
"""Poll until GET /user/projects/file/{file_id} returns 404."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} still accessible after {timeout}s (expected 404)"
|
||||
)
|
||||
|
||||
|
||||
def test_file_upload_process_delete_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Full lifecycle: upload → COMPLETED → unlink → delete → 404.
|
||||
|
||||
Validates that the API server handles all background processing
|
||||
(via FastAPI BackgroundTasks) without any Celery workers running.
|
||||
"""
|
||||
project = ProjectManager.create(
|
||||
name="lifecycle-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
file_content = b"Integration test file content for lifecycle verification."
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("lifecycle.txt", file_content)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert upload_result.user_files, "Expected at least one file in upload response"
|
||||
|
||||
user_file = upload_result.user_files[0]
|
||||
file_id = user_file.id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
project_files = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert any(
|
||||
f.id == file_id for f in project_files
|
||||
), "File should be listed in project files after processing"
|
||||
|
||||
# Unlink the file from the project so the delete endpoint will proceed
|
||||
unlink_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
unlink_resp.status_code == 204
|
||||
), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
|
||||
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
delete_resp.ok
|
||||
), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
|
||||
body = delete_resp.json()
|
||||
assert (
|
||||
body["has_associations"] is False
|
||||
), f"File still has associations after unlink: {body}"
|
||||
|
||||
_file_is_gone(file_id, admin_user)
|
||||
|
||||
project_files_after = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert not any(
|
||||
f.id == file_id for f in project_files_after
|
||||
), "Deleted file should not appear in project files"
|
||||
|
||||
|
||||
def test_delete_blocked_while_associated(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a file that still belongs to a project should return
|
||||
has_associations=True without actually deleting the file."""
|
||||
project = ProjectManager.create(
|
||||
name="assoc-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("assoc.txt", b"associated file content")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
file_id = upload_result.user_files[0].id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
# Attempt to delete while still linked
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_resp.ok
|
||||
body = delete_resp.json()
|
||||
assert body["has_associations"] is True, "Should report existing associations"
|
||||
assert project.name in body["project_names"]
|
||||
|
||||
# File should still be accessible
|
||||
get_resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert get_resp.status_code == 200, "File should still exist after blocked delete"
|
||||
@@ -40,7 +40,6 @@ def test_persona_create_update_share_delete(
|
||||
expected_persona,
|
||||
name=f"updated-{expected_persona.name}",
|
||||
description=f"updated-{expected_persona.description}",
|
||||
num_chunks=expected_persona.num_chunks + 1,
|
||||
is_public=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -31,11 +31,7 @@ def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
|
||||
@@ -31,9 +31,8 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None: # noqa
|
||||
"search, web browsing, and image generation"
|
||||
in unified_assistant.description.lower()
|
||||
)
|
||||
assert unified_assistant.is_default_persona is True
|
||||
assert unified_assistant.featured is True
|
||||
assert unified_assistant.is_visible is True
|
||||
assert unified_assistant.num_chunks == 25
|
||||
|
||||
# Verify tools
|
||||
tools = unified_assistant.tools
|
||||
|
||||
552
backend/tests/integration/tests/scim/test_scim_groups.py
Normal file
552
backend/tests/integration/tests/scim/test_scim_groups.py
Normal file
@@ -0,0 +1,552 @@
|
||||
"""Integration tests for SCIM group provisioning endpoints.
|
||||
|
||||
Covers the full group lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a group via POST /Groups
|
||||
2. Retrieve a group via GET /Groups/{id}
|
||||
3. List, filter, and paginate groups via GET /Groups
|
||||
4. Replace a group via PUT /Groups/{id}
|
||||
5. Patch a group (add/remove members, rename) via PATCH /Groups/{id}
|
||||
6. Delete a group via DELETE /Groups/{id}
|
||||
7. Error cases: duplicate name, not-found, invalid member IDs
|
||||
|
||||
All tests are parameterized across IdP request styles (Okta sends lowercase
|
||||
PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server
|
||||
normalizes both — these tests verify that.
|
||||
|
||||
Auth tests live in test_scim_tokens.py.
|
||||
User lifecycle tests live in test_scim_users.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
|
||||
def idp_style(request: pytest.FixtureRequest) -> str:
|
||||
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scim_token(idp_style: str) -> str:
|
||||
"""Create a single SCIM token shared across all tests in this module.
|
||||
|
||||
Creating a new token revokes the previous one, so we create exactly once
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
token = ScimTokenManager.create(
|
||||
name=f"scim-group-tests-{idp_style}",
|
||||
user_performing_action=admin,
|
||||
).raw_token
|
||||
assert token is not None
|
||||
return token
|
||||
|
||||
|
||||
def _make_group_resource(
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal SCIM GroupResource payload."""
|
||||
resource: dict = {
|
||||
"schemas": [SCIM_GROUP_SCHEMA],
|
||||
"displayName": display_name,
|
||||
}
|
||||
if external_id is not None:
|
||||
resource["externalId"] = external_id
|
||||
if members is not None:
|
||||
resource["members"] = members
|
||||
return resource
|
||||
|
||||
|
||||
def _make_user_resource(email: str, external_id: str) -> dict:
|
||||
"""Build a minimal SCIM UserResource payload for member creation."""
|
||||
return {
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"externalId": external_id,
|
||||
"name": {"givenName": "Test", "familyName": "User"},
|
||||
"active": True,
|
||||
}
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
|
||||
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
|
||||
|
||||
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
|
||||
``"replace"``). The server's ``normalize_operation`` validator lowercases
|
||||
them — these tests verify that both casings are accepted.
|
||||
"""
|
||||
cased_operations = []
|
||||
for operation in operations:
|
||||
cased = dict(operation)
|
||||
if idp_style == "entra":
|
||||
cased["op"] = operation["op"].capitalize()
|
||||
cased_operations.append(cased)
|
||||
return {
|
||||
"schemas": [SCIM_PATCH_SCHEMA],
|
||||
"Operations": cased_operations,
|
||||
}
|
||||
|
||||
|
||||
def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Users", token, json=_make_user_resource(email, external_id)
|
||||
)
|
||||
|
||||
|
||||
def _create_scim_group(
|
||||
token: str,
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Groups",
|
||||
token,
|
||||
json=_make_group_resource(display_name, external_id, members),
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create → get → list → replace → patch → delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups creates a group and returns 201."""
|
||||
name = f"Engineering {idp_style}"
|
||||
resp = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}")
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == name
|
||||
assert body["externalId"] == f"ext-eng-{idp_style}"
|
||||
assert body["id"] # integer ID assigned by server
|
||||
assert body["meta"]["resourceType"] == "Group"
|
||||
|
||||
|
||||
def test_create_group_with_members(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with members populates the member list."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Backend Team {idp_style}",
|
||||
external_id=f"ext-backend-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_get_group(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups/{id} returns the group resource including members."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Frontend Team {idp_style}",
|
||||
external_id=f"ext-fe-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["id"] == created["id"]
|
||||
assert body["displayName"] == f"Frontend Team {idp_style}"
|
||||
assert body["externalId"] == f"ext-fe-{idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_list_groups(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups returns a ListResponse containing provisioned groups."""
|
||||
name = f"DevOps Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}")
|
||||
|
||||
resp = ScimClient.get("/Groups", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] >= 1
|
||||
names = [r["displayName"] for r in body["Resources"]]
|
||||
assert name in names
|
||||
|
||||
|
||||
def test_list_groups_pagination(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups with startIndex and count returns correct pagination."""
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group A {idp_style}", external_id=f"ext-page-a-{idp_style}"
|
||||
)
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group B {idp_style}", external_id=f"ext-page-b-{idp_style}"
|
||||
)
|
||||
|
||||
resp = ScimClient.get("/Groups?startIndex=1&count=1", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["startIndex"] == 1
|
||||
assert body["itemsPerPage"] == 1
|
||||
assert body["totalResults"] >= 2
|
||||
assert len(body["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=displayName eq '...' returns only matching groups."""
|
||||
name = f"Unique QA Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}")
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["displayName"] == name
|
||||
|
||||
|
||||
def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=externalId eq '...' returns the matching group."""
|
||||
ext_id = f"ext-unique-group-id-{idp_style}"
|
||||
_create_scim_group(
|
||||
scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id
|
||||
)
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["externalId"] == ext_id
|
||||
|
||||
|
||||
def test_replace_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} replaces the group resource."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Original Name {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}"
|
||||
).json()
|
||||
|
||||
updated_resource = _make_group_resource(
|
||||
display_name=f"Renamed Group {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == f"Renamed Group {idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} with empty members removes all memberships."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_clear_m_{idp_style}@example.com", f"ext-gcm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Clear Members Group {idp_style}",
|
||||
external_id=f"ext-clear-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[]
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_add_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=add adds a member."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Add Group {idp_style}",
|
||||
external_id=f"ext-patch-add-{idp_style}",
|
||||
).json()
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_patch_remove_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=remove removes a specific member."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Remove Group {idp_style}",
|
||||
external_id=f"ext-patch-rm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "remove",
|
||||
"path": f'members[value eq "{user["id"]}"]',
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_replace_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on members swaps the entire list."""
|
||||
user_a = _create_scim_user(
|
||||
scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}"
|
||||
).json()
|
||||
user_b = _create_scim_user(
|
||||
scim_token, f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Replace Group {idp_style}",
|
||||
external_id=f"ext-patch-repl-{idp_style}",
|
||||
members=[{"value": user_a["id"]}],
|
||||
).json()
|
||||
|
||||
# Replace member list: swap A for B
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "members",
|
||||
"value": [{"value": user_b["id"]}],
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user_b["id"] in member_ids
|
||||
assert user_a["id"] not in member_ids
|
||||
|
||||
|
||||
def test_patch_rename_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on displayName renames the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Old Group Name {idp_style}",
|
||||
external_id=f"ext-rename-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
new_name = f"New Group Name {idp_style}"
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": new_name}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["displayName"] == new_name
|
||||
|
||||
# Confirm via GET
|
||||
get_resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert get_resp.json()["displayName"] == new_name
|
||||
|
||||
|
||||
def test_delete_group(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete Me Group {idp_style}",
|
||||
external_id=f"ext-del-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Second DELETE returns 404 (group hard-deleted)
|
||||
resp2 = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes memberships but does not deactivate users."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete With Members {idp_style}",
|
||||
external_id=f"ext-del-wm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# User should still be active and retrievable
|
||||
user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token)
|
||||
assert user_resp.status_code == 200
|
||||
assert user_resp.json()["active"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with an already-taken displayName returns 409."""
|
||||
name = f"Dup Name Group {idp_style}"
|
||||
resp1 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}")
|
||||
assert resp1.status_code == 201
|
||||
|
||||
resp2 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}")
|
||||
assert resp2.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_group(scim_token: str) -> None:
|
||||
"""GET /Groups/{bad-id} returns 404."""
|
||||
resp = ScimClient.get("/Groups/999999999", scim_token)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with a non-existent member UUID returns 400."""
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Bad Member Group {idp_style}",
|
||||
external_id=f"ext-bad-m-{idp_style}",
|
||||
members=[{"value": "00000000-0000-0000-0000-000000000000"}],
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "not found" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} adding a non-existent member returns 400."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Bad Member Group {idp_style}",
|
||||
external_id=f"ext-pbm-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "add",
|
||||
"path": "members",
|
||||
"value": [{"value": "00000000-0000-0000-0000-000000000000"}],
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "not found" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_add_duplicate_member_is_idempotent(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""PATCH /Groups/{id} adding an already-present member succeeds silently."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Idempotent Add Group {idp_style}",
|
||||
external_id=f"ext-idem-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
# Add same member again
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["members"]) == 1 # still just one member
|
||||
@@ -15,6 +15,7 @@ import time
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -39,7 +40,7 @@ def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
response = ScimClient.get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
@@ -54,7 +55,7 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
response = ScimClient.get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
@@ -69,25 +70,22 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
assert ScimClient.get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimClient.get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
assert ScimClient.get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
@@ -139,7 +137,7 @@ def test_service_discovery_no_auth_required(
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
response = ScimClient.get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
@@ -158,7 +156,7 @@ def test_last_used_at_updated_after_scim_request(
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
assert ScimClient.get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
|
||||
520
backend/tests/integration/tests/scim/test_scim_users.py
Normal file
520
backend/tests/integration/tests/scim/test_scim_users.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""Integration tests for SCIM user provisioning endpoints.
|
||||
|
||||
Covers the full user lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a user via POST /Users
|
||||
2. Retrieve a user via GET /Users/{id}
|
||||
3. List, filter, and paginate users via GET /Users
|
||||
4. Replace a user via PUT /Users/{id}
|
||||
5. Patch a user (deactivate/reactivate) via PATCH /Users/{id}
|
||||
6. Delete a user via DELETE /Users/{id}
|
||||
7. Error cases: missing externalId, duplicate email, not-found, seat limit
|
||||
|
||||
All tests are parameterized across IdP request styles:
|
||||
- **Okta**: lowercase PATCH ops, minimal payloads (core schema only).
|
||||
- **Entra**: capitalized ops (``"Replace"``), enterprise extension data
|
||||
(department, manager), and structured email arrays.
|
||||
|
||||
The server normalizes both — these tests verify that all IdP-specific fields
|
||||
are accepted and round-tripped correctly.
|
||||
|
||||
Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
|
||||
def idp_style(request: pytest.FixtureRequest) -> str:
|
||||
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scim_token(idp_style: str) -> str:
|
||||
"""Create a single SCIM token shared across all tests in this module.
|
||||
|
||||
Creating a new token revokes the previous one, so we create exactly once
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
token = ScimTokenManager.create(
|
||||
name=f"scim-user-tests-{idp_style}",
|
||||
user_performing_action=admin,
|
||||
).raw_token
|
||||
assert token is not None
|
||||
return token
|
||||
|
||||
|
||||
def _make_user_resource(
|
||||
email: str,
|
||||
external_id: str,
|
||||
given_name: str = "Test",
|
||||
family_name: str = "User",
|
||||
active: bool = True,
|
||||
idp_style: str = "okta",
|
||||
department: str | None = None,
|
||||
manager_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a SCIM UserResource payload appropriate for the IdP style.
|
||||
|
||||
Entra sends richer payloads including enterprise extension data (department,
|
||||
manager), structured email arrays, and the enterprise schema URN. Okta sends
|
||||
minimal payloads with just core user fields.
|
||||
"""
|
||||
resource: dict = {
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"externalId": external_id,
|
||||
"name": {
|
||||
"givenName": given_name,
|
||||
"familyName": family_name,
|
||||
},
|
||||
"active": active,
|
||||
}
|
||||
if idp_style == "entra":
|
||||
dept = department or "Engineering"
|
||||
mgr = manager_id or "mgr-ext-001"
|
||||
resource["schemas"].append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
resource[SCIM_ENTERPRISE_USER_SCHEMA] = {
|
||||
"department": dept,
|
||||
"manager": {"value": mgr},
|
||||
}
|
||||
resource["emails"] = [
|
||||
{"value": email, "type": "work", "primary": True},
|
||||
]
|
||||
return resource
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
|
||||
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
|
||||
|
||||
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
|
||||
``"replace"``). The server's ``normalize_operation`` validator lowercases
|
||||
them — these tests verify that both casings are accepted.
|
||||
"""
|
||||
cased_operations = []
|
||||
for operation in operations:
|
||||
cased = dict(operation)
|
||||
if idp_style == "entra":
|
||||
cased["op"] = operation["op"].capitalize()
|
||||
cased_operations.append(cased)
|
||||
return {
|
||||
"schemas": [SCIM_PATCH_SCHEMA],
|
||||
"Operations": cased_operations,
|
||||
}
|
||||
|
||||
|
||||
def _create_scim_user(
|
||||
token: str,
|
||||
email: str,
|
||||
external_id: str,
|
||||
idp_style: str = "okta",
|
||||
) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Users",
|
||||
token,
|
||||
json=_make_user_resource(email, external_id, idp_style=idp_style),
|
||||
)
|
||||
|
||||
|
||||
def _assert_entra_extension(
|
||||
body: dict,
|
||||
expected_department: str = "Engineering",
|
||||
expected_manager: str = "mgr-ext-001",
|
||||
) -> None:
|
||||
"""Assert that Entra enterprise extension fields round-tripped correctly."""
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"]
|
||||
ext = body[SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
assert ext["department"] == expected_department
|
||||
assert ext["manager"]["value"] == expected_manager
|
||||
|
||||
|
||||
def _assert_entra_emails(body: dict, expected_email: str) -> None:
|
||||
"""Assert that structured email metadata round-tripped correctly."""
|
||||
emails = body["emails"]
|
||||
assert len(emails) >= 1
|
||||
work_email = next(e for e in emails if e.get("type") == "work")
|
||||
assert work_email["value"] == expected_email
|
||||
assert work_email["primary"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create -> get -> list -> replace -> patch -> delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users creates a provisioned user and returns 201."""
|
||||
email = f"scim_create_{idp_style}@example.com"
|
||||
ext_id = f"ext-create-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
assert body["userName"] == email
|
||||
assert body["externalId"] == ext_id
|
||||
assert body["active"] is True
|
||||
assert body["id"] # UUID assigned by server
|
||||
assert body["meta"]["resourceType"] == "User"
|
||||
assert body["name"]["givenName"] == "Test"
|
||||
assert body["name"]["familyName"] == "User"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body)
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
ext_id = f"ext-get-{idp_style}"
|
||||
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
|
||||
|
||||
resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["id"] == created["id"]
|
||||
assert body["userName"] == email
|
||||
assert body["externalId"] == ext_id
|
||||
assert body["name"]["givenName"] == "Test"
|
||||
assert body["name"]["familyName"] == "User"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body)
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_list_users(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users returns a ListResponse containing provisioned users."""
|
||||
email = f"scim_list_{idp_style}@example.com"
|
||||
_create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style)
|
||||
|
||||
resp = ScimClient.get("/Users", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] >= 1
|
||||
emails = [r["userName"] for r in body["Resources"]]
|
||||
assert email in emails
|
||||
|
||||
|
||||
def test_list_users_pagination(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users with startIndex and count returns correct pagination."""
|
||||
_create_scim_user(
|
||||
scim_token,
|
||||
f"scim_page1_{idp_style}@example.com",
|
||||
f"ext-page-1-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
_create_scim_user(
|
||||
scim_token,
|
||||
f"scim_page2_{idp_style}@example.com",
|
||||
f"ext-page-2-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
|
||||
resp = ScimClient.get("/Users?startIndex=1&count=1", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["startIndex"] == 1
|
||||
assert body["itemsPerPage"] == 1
|
||||
assert body["totalResults"] >= 2
|
||||
assert len(body["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_users_by_username(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users?filter=userName eq '...' returns only matching users."""
|
||||
email = f"scim_filter_{idp_style}@example.com"
|
||||
_create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style)
|
||||
|
||||
resp = ScimClient.get(f'/Users?filter=userName eq "{email}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["userName"] == email
|
||||
|
||||
|
||||
def test_replace_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Users/{id} replaces the user resource including enterprise fields."""
|
||||
email = f"scim_replace_{idp_style}@example.com"
|
||||
ext_id = f"ext-replace-{idp_style}"
|
||||
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
|
||||
|
||||
updated_resource = _make_user_resource(
|
||||
email=email,
|
||||
external_id=ext_id,
|
||||
given_name="Updated",
|
||||
family_name="Name",
|
||||
idp_style=idp_style,
|
||||
department="Product",
|
||||
)
|
||||
resp = ScimClient.put(f"/Users/{created['id']}", scim_token, json=updated_resource)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["name"]["givenName"] == "Updated"
|
||||
assert body["name"]["familyName"] == "Name"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body, expected_department="Product")
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Users/{id} with active=false deactivates the user."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_deactivate_{idp_style}@example.com",
|
||||
f"ext-deactivate-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
assert created["active"] is True
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active"] is False
|
||||
|
||||
# Confirm via GET
|
||||
get_resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
|
||||
assert get_resp.json()["active"] is False
|
||||
|
||||
|
||||
def test_patch_reactivate_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH active=true reactivates a previously deactivated user."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_reactivate_{idp_style}@example.com",
|
||||
f"ext-reactivate-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
|
||||
# Deactivate
|
||||
deactivate_resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert deactivate_resp.status_code == 200
|
||||
assert deactivate_resp.json()["active"] is False
|
||||
|
||||
# Reactivate
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": True}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active"] is True
|
||||
|
||||
|
||||
def test_delete_user(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Users/{id} deactivates and removes the SCIM mapping."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_delete_{idp_style}@example.com",
|
||||
f"ext-delete-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Users/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Second DELETE returns 404 per RFC 7644 §3.6 (mapping removed)
|
||||
resp2 = ScimClient.delete(f"/Users/{created['id']}", scim_token)
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users without externalId succeeds (RFC 7643: externalId is optional)."""
|
||||
email = f"scim_no_extid_{idp_style}@example.com"
|
||||
resp = ScimClient.post(
|
||||
"/Users",
|
||||
scim_token,
|
||||
json={
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"active": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["userName"] == email
|
||||
assert body.get("externalId") is None
|
||||
|
||||
|
||||
def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users with an already-taken email returns 409."""
|
||||
email = f"scim_dup_{idp_style}@example.com"
|
||||
resp1 = _create_scim_user(scim_token, email, f"ext-dup-1-{idp_style}", idp_style)
|
||||
assert resp1.status_code == 201
|
||||
|
||||
resp2 = _create_scim_user(scim_token, email, f"ext-dup-2-{idp_style}", idp_style)
|
||||
assert resp2.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_user(scim_token: str) -> None:
|
||||
"""GET /Users/{bad-id} returns 404."""
|
||||
resp = ScimClient.get("/Users/00000000-0000-0000-0000-000000000000", scim_token)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users?filter=externalId eq '...' returns the matching user."""
|
||||
ext_id = f"ext-unique-filter-id-{idp_style}"
|
||||
_create_scim_user(
|
||||
scim_token, f"scim_extfilter_{idp_style}@example.com", ext_id, idp_style
|
||||
)
|
||||
|
||||
resp = ScimClient.get(f'/Users?filter=externalId eq "{ext_id}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["externalId"] == ext_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seat-limit enforcement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
|
||||
"""Write a LicenseMetadata entry into Redis with the given seat cap."""
|
||||
now = datetime.now(timezone.utc)
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id="public",
|
||||
organization_name="Test Org",
|
||||
seats=seats,
|
||||
used_seats=0, # check_seat_availability recalculates from DB
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
|
||||
|
||||
|
||||
def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users returns 403 when the seat limit is reached."""
|
||||
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
# admin_user already occupies 1 seat; cap at 1 -> full
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
resp = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_blocked_{idp_style}@example.com",
|
||||
f"ext-blocked-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "seat" in resp.json()["detail"].lower()
|
||||
finally:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
|
||||
|
||||
def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH active=true returns 403 when the seat limit is reached."""
|
||||
# Create and deactivate a user (before license is seeded)
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_reactivate_blocked_{idp_style}@example.com",
|
||||
f"ext-reactivate-blocked-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
assert created["active"] is True
|
||||
|
||||
deactivate_resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert deactivate_resp.status_code == 200
|
||||
assert deactivate_resp.json()["active"] is False
|
||||
|
||||
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
# Seed license capped at current active users -> reactivation should fail
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": True}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "seat" in resp.json()["detail"].lower()
|
||||
finally:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Integration tests for Slack user deactivation and reactivation via admin endpoints.
|
||||
|
||||
Verifies that:
|
||||
- Slack users can be deactivated by admins
|
||||
- Deactivated Slack users can be reactivated by admins
|
||||
- Reactivation is blocked when the seat limit is reached
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
|
||||
now = datetime.utcnow()
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id="public",
|
||||
organization_name="Test Org",
|
||||
seats=seats,
|
||||
used_seats=0,
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
|
||||
|
||||
|
||||
def _clear_license(r: redis.Redis) -> None:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
|
||||
|
||||
def _redis() -> redis.Redis:
|
||||
return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
|
||||
def _get_user_is_active(email: str, admin_user: DATestUser) -> bool:
|
||||
"""Look up a user's is_active flag via the admin users list endpoint."""
|
||||
result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
search_query=email,
|
||||
)
|
||||
matching = [u for u in result.items if u.email == email]
|
||||
assert len(matching) == 1, f"Expected exactly 1 user with email {email}"
|
||||
return matching[0].is_active
|
||||
|
||||
|
||||
def test_slack_user_deactivate_and_reactivate(reset: None) -> None: # noqa: ARG001
|
||||
"""Admin can deactivate and then reactivate a Slack user."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
slack_user = UserManager.create(name="slack_test_user")
|
||||
slack_user = UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Deactivate the Slack user
|
||||
UserManager.set_status(
|
||||
slack_user, target_status=False, user_performing_action=admin_user
|
||||
)
|
||||
assert _get_user_is_active(slack_user.email, admin_user) is False
|
||||
|
||||
# Reactivate the Slack user
|
||||
UserManager.set_status(
|
||||
slack_user, target_status=True, user_performing_action=admin_user
|
||||
)
|
||||
assert _get_user_is_active(slack_user.email, admin_user) is True
|
||||
|
||||
|
||||
def test_slack_user_reactivation_blocked_by_seat_limit(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Reactivating a deactivated Slack user returns 402 when seats are full."""
|
||||
r = _redis()
|
||||
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
slack_user = UserManager.create(name="slack_test_user")
|
||||
slack_user = UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
UserManager.set_status(
|
||||
slack_user, target_status=False, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
# License allows 1 seat — only admin counts
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/activate-user",
|
||||
json={"user_email": slack_user.email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 402
|
||||
finally:
|
||||
_clear_license(r)
|
||||
@@ -1,11 +1,20 @@
|
||||
"""Tests for license database CRUD operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.db.models import License
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class TestGetLicense:
|
||||
@@ -100,3 +109,108 @@ class TestDeleteLicense:
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def _make_license_metadata(seats: int = 10) -> LicenseMetadata:
|
||||
now = datetime.now(timezone.utc)
|
||||
return LicenseMetadata(
|
||||
tenant_id="public",
|
||||
seats=seats,
|
||||
used_seats=0,
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilitySelfHosted:
|
||||
"""Seat checks for self-hosted (MULTI_TENANT=False)."""
|
||||
|
||||
@patch("ee.onyx.db.license.get_license_metadata", return_value=None)
|
||||
def test_no_license_means_unlimited(self, _mock_meta: MagicMock) -> None:
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=5)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_available(self, mock_meta: MagicMock, _mock_used: MagicMock) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=10)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_full_blocks_creation(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
assert "10 of 10" in result.error_message
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=10)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_exactly_at_capacity_allows_no_more(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
"""Filling to 100% is allowed; exceeding is not."""
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is False
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=9)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_filling_to_capacity_is_allowed(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilityMultiTenant:
|
||||
"""Seat checks for multi-tenant cloud (MULTI_TENANT=True).
|
||||
|
||||
Verifies that get_used_seats takes the MULTI_TENANT branch
|
||||
and delegates to get_tenant_count.
|
||||
"""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=5,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_available_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is True
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=10,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_full_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
"""Tests for the _impl functions' redis_locking parameter.
|
||||
|
||||
Verifies that:
|
||||
- redis_locking=True acquires/releases Redis locks and clears queued keys
|
||||
- redis_locking=False skips all Redis operations entirely
|
||||
- Both paths execute the same business logic (DB lookup, status check)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _mock_session_returning_none() -> MagicMock:
|
||||
"""Return a mock session whose .get() returns None (file not found)."""
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
return session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once_with(tenant_id="test-tenant")
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_both_paths_call_db_get(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
"""Both redis_locking=True and False should call db_session.get(UserFile, ...)."""
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
uid = str(uuid4())
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=True)
|
||||
call_count_true = session.get.call_count
|
||||
|
||||
session.reset_mock()
|
||||
mock_get_session.reset_mock()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=False)
|
||||
call_count_false = session.get.call_count
|
||||
|
||||
assert call_count_true == call_count_false == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
@@ -0,0 +1,421 @@
|
||||
"""Tests for no-vector-DB user file processing paths.
|
||||
|
||||
Verifies that when DISABLE_VECTOR_DB is True:
|
||||
- process_user_file_impl calls _process_user_file_without_vector_db (not indexing)
|
||||
- _process_user_file_without_vector_db extracts text, counts tokens, stores plaintext,
|
||||
sets status=COMPLETED and chunk_count=0
|
||||
- delete_user_file_impl skips vector DB chunk deletion
|
||||
- project_sync_user_file_impl skips vector DB metadata update
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_without_vector_db,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
LLM_FACTORY_MODULE = "onyx.llm.factory"
|
||||
|
||||
|
||||
def _make_documents(texts: list[str]) -> list[Document]:
|
||||
"""Build a list of Document objects with the given section texts."""
|
||||
return [
|
||||
Document(
|
||||
id=str(uuid4()),
|
||||
source=DocumentSource.USER_FILE,
|
||||
sections=[TextSection(text=t)],
|
||||
semantic_identifier=f"test-doc-{i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, t in enumerate(texts)
|
||||
]
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
file_id: str = "test-file-id",
|
||||
name: str = "test.txt",
|
||||
) -> MagicMock:
|
||||
"""Return a MagicMock mimicking a UserFile ORM instance."""
|
||||
uf = MagicMock()
|
||||
uf.id = uuid4()
|
||||
uf.file_id = file_id
|
||||
uf.name = name
|
||||
uf.status = status
|
||||
uf.token_count = None
|
||||
uf.chunk_count = None
|
||||
uf.last_project_sync_at = None
|
||||
uf.projects = []
|
||||
uf.assistants = []
|
||||
uf.needs_project_sync = True
|
||||
uf.needs_persona_sync = True
|
||||
return uf
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_without_vector_db — direct tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileWithoutVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_extracts_and_combines_text(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=[1, 2, 3, 4, 5])
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["hello world", "foo bar"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
stored_text = mock_store_plaintext.call_args.kwargs["plaintext_content"]
|
||||
assert "hello world" in stored_text
|
||||
assert "foo bar" in stored_text
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_computes_token_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=list(range(42)))
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["some text content"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count == 42
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_token_count_falls_back_to_none_on_error(
|
||||
self,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_encode: MagicMock, # noqa: ARG002
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_llm.side_effect = RuntimeError("No LLM configured")
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count is None
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_stores_plaintext(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["content to store"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
mock_store_plaintext.assert_called_once_with(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content="content to store",
|
||||
)
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_sets_completed_status_and_zero_chunk_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.COMPLETED
|
||||
assert uf.chunk_count == 0
|
||||
assert uf.last_project_sync_at is not None
|
||||
db_session.add.assert_called_once_with(uf)
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_preserves_deleting_status(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.DELETING
|
||||
assert uf.chunk_count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl — branching on DISABLE_VECTOR_DB
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessImplBranching:
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_without_vector_db_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_without_vdb.assert_called_once()
|
||||
mock_with_indexing.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_with_indexing_when_vector_db_enabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_with_indexing.assert_called_once()
|
||||
mock_without_vdb.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.run_indexing_pipeline")
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_indexing_pipeline_not_called_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
mock_run_pipeline: MagicMock,
|
||||
) -> None:
|
||||
"""End-to-end: verify run_indexing_pipeline is never invoked."""
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["content"])]
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock),
|
||||
patch(f"{LLM_FACTORY_MODULE}.get_default_llm"),
|
||||
patch(
|
||||
f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func",
|
||||
return_value=MagicMock(return_value=[1, 2, 3]),
|
||||
),
|
||||
):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_run_pipeline.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_deletion(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
mock_get_file_store.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_deletes_file_store_and_db_record(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert file_store.delete_file.call_count == 2
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_update(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_clears_sync_flags(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert uf.needs_project_sync is False
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.last_project_sync_at is not None
|
||||
session.add.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
166
backend/tests/unit/onyx/chat/test_stop_signal_checker.py
Normal file
166
backend/tests/unit/onyx/chat/test_stop_signal_checker.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Unit tests for stop_signal_checker and chat_processing_checker.
|
||||
|
||||
These modules are safety-critical — they control whether a chat stream
|
||||
continues or stops. The tests use a simple in-memory CacheBackend stub
|
||||
so no external services are needed.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.stop_signal_checker import FENCE_TTL
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None, # noqa: ARG002
|
||||
) -> None:
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ── stop_signal_checker ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetFence:
|
||||
def test_set_fence_true_creates_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_removes_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_noop_when_absent(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_uses_ttl(self) -> None:
|
||||
"""Verify set_fence passes ex=FENCE_TTL to cache.set."""
|
||||
calls: list[dict[str, object]] = []
|
||||
cache = _MemoryCacheBackend()
|
||||
original_set = cache.set
|
||||
|
||||
def tracking_set(
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
calls.append({"key": key, "ex": ex})
|
||||
original_set(key, value, ex=ex)
|
||||
|
||||
cache.set = tracking_set # type: ignore[method-assign]
|
||||
|
||||
set_fence(uuid4(), cache, True)
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["ex"] == FENCE_TTL
|
||||
|
||||
|
||||
class TestIsConnected:
|
||||
def test_connected_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert is_connected(uuid4(), cache)
|
||||
|
||||
def test_disconnected_when_fence_set(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_fence(sid1, cache, True)
|
||||
assert not is_connected(sid1, cache)
|
||||
assert is_connected(sid2, cache)
|
||||
|
||||
|
||||
class TestResetCancelStatus:
|
||||
def test_clears_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
reset_cancel_status(sid, cache)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_noop_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
reset_cancel_status(uuid4(), cache)
|
||||
|
||||
|
||||
# ── chat_processing_checker ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetProcessingStatus:
|
||||
def test_set_true_marks_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
assert is_chat_session_processing(sid, cache)
|
||||
|
||||
def test_set_false_clears_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
set_processing_status(sid, cache, False)
|
||||
assert not is_chat_session_processing(sid, cache)
|
||||
|
||||
|
||||
class TestIsChatSessionProcessing:
|
||||
def test_not_processing_by_default(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert not is_chat_session_processing(uuid4(), cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_processing_status(sid1, cache, True)
|
||||
assert is_chat_session_processing(sid1, cache)
|
||||
assert not is_chat_session_processing(sid2, cache)
|
||||
163
backend/tests/unit/onyx/federated_connectors/test_oauth_utils.py
Normal file
163
backend/tests/unit/onyx/federated_connectors/test_oauth_utils.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Unit tests for federated OAuth state generation and verification.
|
||||
|
||||
Uses unittest.mock to patch get_cache_backend so no external services
|
||||
are needed. Verifies the generate -> verify round-trip, one-time-use
|
||||
semantics, TTL propagation, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.federated_connectors.oauth_utils import generate_oauth_state
|
||||
from onyx.federated_connectors.oauth_utils import OAUTH_STATE_TTL
|
||||
from onyx.federated_connectors.oauth_utils import OAuthSession
|
||||
from onyx.federated_connectors.oauth_utils import verify_oauth_state
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
self.set_calls: list[dict[str, object]] = []
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self.set_calls.append({"key": key, "ex": ex})
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _patched(cache: _MemoryCacheBackend): # type: ignore[no-untyped-def]
|
||||
return patch(
|
||||
"onyx.federated_connectors.oauth_utils.get_cache_backend",
|
||||
return_value=cache,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateAndVerifyRoundTrip:
|
||||
def test_round_trip_basic(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=42,
|
||||
user_id="user-abc",
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 42
|
||||
assert session.user_id == "user-abc"
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
|
||||
def test_round_trip_with_all_fields(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=7,
|
||||
user_id="user-xyz",
|
||||
redirect_uri="https://example.com/callback",
|
||||
additional_data={"scope": "read"},
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 7
|
||||
assert session.user_id == "user-xyz"
|
||||
assert session.redirect_uri == "https://example.com/callback"
|
||||
assert session.additional_data == {"scope": "read"}
|
||||
|
||||
|
||||
class TestOneTimeUse:
|
||||
def test_verify_deletes_state(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
verify_oauth_state(state)
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestTTLPropagation:
|
||||
def test_default_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
|
||||
assert len(cache.set_calls) == 1
|
||||
assert cache.set_calls[0]["ex"] == OAUTH_STATE_TTL
|
||||
|
||||
def test_custom_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u", ttl=600)
|
||||
|
||||
assert cache.set_calls[0]["ex"] == 600
|
||||
|
||||
|
||||
class TestVerifyInvalidState:
|
||||
def test_missing_state_raises(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
# Manually clear the cache to simulate expiration
|
||||
cache._store.clear()
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestOAuthSessionSerialization:
|
||||
def test_to_dict_from_dict_round_trip(self) -> None:
|
||||
session = OAuthSession(
|
||||
federated_connector_id=5,
|
||||
user_id="u-123",
|
||||
redirect_uri="https://redir.example.com",
|
||||
additional_data={"key": "val"},
|
||||
)
|
||||
d = session.to_dict()
|
||||
restored = OAuthSession.from_dict(d)
|
||||
|
||||
assert restored.federated_connector_id == 5
|
||||
assert restored.user_id == "u-123"
|
||||
assert restored.redirect_uri == "https://redir.example.com"
|
||||
assert restored.additional_data == {"key": "val"}
|
||||
|
||||
def test_from_dict_defaults(self) -> None:
|
||||
minimal = {"federated_connector_id": 1, "user_id": "u"}
|
||||
session = OAuthSession.from_dict(minimal)
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
@@ -19,6 +19,7 @@ from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
@@ -26,6 +27,10 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
# Every supported SCIM provider must appear here so that all endpoint tests
|
||||
# run against it. When adding a new provider, add its class to this list.
|
||||
SCIM_PROVIDERS: list[type[ScimProvider]] = [OktaProvider, EntraProvider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
@@ -41,10 +46,10 @@ def mock_token() -> MagicMock:
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider() -> ScimProvider:
|
||||
"""An OktaProvider instance for endpoint tests."""
|
||||
return OktaProvider()
|
||||
@pytest.fixture(params=SCIM_PROVIDERS, ids=[p.__name__ for p in SCIM_PROVIDERS])
|
||||
def provider(request: pytest.FixtureRequest) -> ScimProvider:
|
||||
"""Parameterized provider — runs each test with every provider in SCIM_PROVIDERS."""
|
||||
return request.param()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.scim.auth import _hash_scim_token
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.auth import SCIM_TOKEN_PREFIX
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestVerifyScimToken:
|
||||
def test_missing_header_raises_401(self) -> None:
|
||||
request = self._make_request(None)
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing" in str(exc_info.value.detail)
|
||||
@@ -68,7 +68,7 @@ class TestVerifyScimToken:
|
||||
def test_wrong_prefix_raises_401(self) -> None:
|
||||
request = self._make_request("Bearer on_some_api_key")
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestVerifyScimToken:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
dal = self._make_dal(token=None)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in str(exc_info.value.detail)
|
||||
@@ -87,7 +87,7 @@ class TestVerifyScimToken:
|
||||
mock_token = MagicMock()
|
||||
mock_token.is_active = False
|
||||
dal = self._make_dal(token=mock_token)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestOktaProvider:
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Madonna", familyName=None, formatted="Madonna"
|
||||
givenName="Madonna", familyName="", formatted="Madonna"
|
||||
)
|
||||
|
||||
def test_build_user_resource_no_name(self) -> None:
|
||||
@@ -117,7 +117,10 @@ class TestOktaProvider:
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name is None
|
||||
# Falls back to deriving name from email local part
|
||||
assert result.name == ScimName(
|
||||
givenName="test", familyName="", formatted="test"
|
||||
)
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user