mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 22:25:47 +00:00
Compare commits
2 Commits
main
...
content-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f5d7e271a | ||
|
|
bb6e20614d |
@@ -54,7 +54,6 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
|
||||
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,9 +426,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -500,9 +499,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -648,8 +646,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -730,8 +728,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -864,9 +862,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -937,9 +934,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1076,8 +1072,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1149,8 +1145,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1291,9 +1287,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1371,9 +1366,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -15,9 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
@@ -28,6 +25,16 @@ jobs:
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
- 'release/**'
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,13 +21,7 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -58,14 +52,21 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -48,10 +48,28 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
@@ -63,7 +81,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -72,19 +89,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
@@ -93,8 +97,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -106,7 +110,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -115,19 +118,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
@@ -136,8 +126,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
@@ -148,7 +138,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -157,19 +146,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
@@ -178,8 +154,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
@@ -194,56 +170,56 @@ jobs:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
@@ -254,7 +230,6 @@ jobs:
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -263,43 +238,21 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -1,112 +0,0 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -18,6 +18,7 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -160,6 +161,12 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -171,7 +178,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
@@ -725,19 +725,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -241,7 +241,8 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,31 +414,34 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -446,18 +449,15 @@ def process_user_file_impl(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -471,6 +471,7 @@ def process_user_file_impl(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -492,8 +493,9 @@ def process_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -502,42 +504,33 @@ def process_user_file_impl(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -588,38 +581,36 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -657,6 +648,7 @@ def delete_user_file_impl(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -664,33 +656,26 @@ def delete_user_file_impl(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -762,30 +747,32 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -796,10 +783,11 @@ def project_sync_user_file_impl(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -834,7 +822,7 @@ def project_sync_user_file_impl(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -847,21 +835,11 @@ def project_sync_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,33 +1,3 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -39,142 +9,3 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
# Commit to release the row lock promptly.
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
45
backend/onyx/cache/factory.py
vendored
45
backend/onyx/cache/factory.py
vendored
@@ -1,45 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
# CacheBackendType.POSTGRES will be added in a follow-up PR.
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
89
backend/onyx/cache/interface.py
vendored
89
backend/onyx/cache/interface.py
vendored
@@ -1,89 +0,0 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
92
backend/onyx/cache/redis_backend.py
vendored
@@ -1,92 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
    """``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.

    This is a thin pass-through — every method maps 1-to-1 to the underlying
    Redis command.  ``TenantRedis`` key-prefixing is handled by the client
    itself (provided by ``get_redis_client``).

    Values read back from the client (``get`` and ``blpop``) are normalized
    to ``bytes`` so the declared return types hold even if the client hands
    back a non-bytes reply.
    """

    def __init__(self, redis_client: Redis) -> None:
        # Client that every command below is forwarded to.
        self._r = redis_client

    @staticmethod
    def _to_bytes(val: object) -> bytes:
        """Return ``val`` unchanged if it is ``bytes``; otherwise encode ``str(val)``."""
        if isinstance(val, bytes):
            return val
        return str(val).encode()

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        """Return the value stored at ``key`` as ``bytes``.

        Returns ``None`` when the client returns ``None`` for the key.
        """
        val = self._r.get(key)
        if val is None:
            return None
        return self._to_bytes(val)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Store ``value`` at ``key``; ``ex`` is forwarded as the client's ``ex`` argument."""
        self._r.set(key, value, ex=ex)

    def delete(self, key: str) -> None:
        """Delete ``key`` via the client; the client's reply is discarded."""
        self._r.delete(key)

    def exists(self, key: str) -> bool:
        """Return the client's ``exists`` reply for ``key`` coerced to ``bool``."""
        return bool(self._r.exists(key))

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        """Forward ``expire(key, seconds)`` to the client; the reply is discarded."""
        self._r.expire(key, seconds)

    def ttl(self, key: str) -> int:
        """Return the client's ``ttl`` reply for ``key`` (typed as ``int``)."""
        return cast(int, self._r.ttl(key))

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Create a client lock named ``name`` and wrap it in ``RedisCacheLock``."""
        return RedisCacheLock(self._r.lock(name, timeout=timeout))

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        """Forward ``rpush(key, value)`` to the client; the reply is discarded."""
        self._r.rpush(key, value)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Forward ``blpop`` to the client and return the first two reply items.

        Returns ``None`` when the client returns ``None``.  Otherwise the
        first two elements of the reply are returned as a tuple, each
        normalized to ``bytes`` the same way ``get`` normalizes its reply.
        """
        result = cast(
            list[bytes | str] | None, self._r.blpop(keys, timeout=timeout)
        )
        if result is None:
            return None
        # Normalize both elements so the tuple really is (bytes, bytes),
        # consistent with ``get``; already-bytes items pass through untouched.
        return (self._to_bytes(result[0]), self._to_bytes(result[1]))
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -531,13 +530,11 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
"read sections of any file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -561,16 +558,12 @@ def _create_context_files_message(
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
|
||||
@@ -98,7 +98,6 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -113,9 +112,6 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -186,7 +186,6 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -109,38 +109,45 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
"""
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
|
||||
@@ -103,6 +103,7 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -3264,6 +3265,19 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3290,8 +3304,11 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
|
||||
@@ -18,8 +18,11 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -251,15 +254,13 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -280,6 +281,7 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -293,7 +295,10 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
featured=create_persona_request.featured,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -869,6 +874,10 @@ def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -889,11 +898,13 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
featured: bool | None = None,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -1004,6 +1015,12 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1017,8 +1034,10 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1071,6 +1090,12 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1086,7 +1111,9 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
featured=(featured if featured is not None else False),
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1131,9 +1158,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_featured(
|
||||
def update_persona_is_default(
|
||||
persona_id: int,
|
||||
featured: bool,
|
||||
is_default: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1141,7 +1168,7 @@ def update_persona_featured(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.featured = featured
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -106,8 +105,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -128,27 +127,16 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,6 +5,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -43,6 +45,8 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -69,13 +73,17 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -562,7 +563,12 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -526,7 +525,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -547,41 +546,3 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
    """
    Index settings for AWS-managed OpenSearch in the multi-tenant cloud.

    The shard count of 324 very roughly targets ~30Gb of storage per shard,
    which per AWS OpenSearch documentation is within a good target range.

    Two replicas yield three copies of the data in total, which is needed
    because the cluster is configured with 3-AZ awareness.
    """
    index_section: dict[str, Any] = {
        "number_of_shards": 324,
        "number_of_replicas": 2,
        # Required for vector search.
        "knn": True,
        "knn.algo_param.ef_search": EF_SEARCH,
    }
    return {"index": index_section}
|
||||
|
||||
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
    """
    Pick the index settings that match the current deployment flags.

    Without AWS-managed OpenSearch the generic settings are used; with it,
    the multi-tenant cloud settings are chosen when ``MULTI_TENANT`` is set
    and the single-tenant dev settings otherwise.
    """
    if not USING_AWS_MANAGED_OPENSEARCH:
        return DocumentSchema.get_index_settings()

    if MULTI_TENANT:
        return DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()

    return DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
|
||||
@@ -32,14 +32,11 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -257,53 +254,8 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
    """Reject ``CACHE_BACKEND=postgres`` unless ``DISABLE_VECTOR_DB`` is set.

    The Postgres cache backend removes the Redis dependency, but it only
    works when Celery is not running, which requires DISABLE_VECTOR_DB=true.

    Raises:
        RuntimeError: if the Postgres backend is selected while the vector
            DB is still enabled.
    """
    uses_postgres_cache = CACHE_BACKEND == CacheBackendType.POSTGRES
    if not uses_postgres_cache or DISABLE_VECTOR_DB:
        # Either not the Postgres backend, or it is paired with
        # DISABLE_VECTOR_DB as required.
        return

    raise RuntimeError(
        "CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
        "The Postgres cache backend is only supported in no-vector-DB "
        "deployments where Celery is replaced by the in-process task runner."
    )
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
    """Ensure ``DISABLE_VECTOR_DB`` is not paired with incompatible modes.

    Does nothing when ``DISABLE_VECTOR_DB`` is off.  When it is on, the
    function refuses to continue if either ``MULTI_TENANT`` or
    ``ENABLE_CRAFT`` is also enabled.

    Raises:
        RuntimeError: if DISABLE_VECTOR_DB is combined with MULTI_TENANT or
            with ENABLE_CRAFT.
    """
    if DISABLE_VECTOR_DB:
        if MULTI_TENANT:
            raise RuntimeError(
                "DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
                "Multi-tenant deployments require the vector database for "
                "per-tenant document indexing and search. Run in single-tenant "
                "mode when disabling the vector database."
            )

        # Imported lazily, only once the checks above have passed.
        from onyx.server.features.build.configs import ENABLE_CRAFT

        if ENABLE_CRAFT:
            raise RuntimeError(
                "DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
                "Onyx Craft requires background workers for sandbox lifecycle "
                "management, which are removed in no-vector-DB deployments. "
                "Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
            )
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -372,20 +324,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -3,12 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -245,44 +243,6 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"check_seat_availability",
|
||||
None,
|
||||
)
|
||||
seat_result = check_seat_fn(db_session=db_session)
|
||||
if seat_result is not None and not seat_result.available:
|
||||
logger.info(
|
||||
f"Blocked inactive Slack user {message_info.email}: "
|
||||
f"{seat_result.error_message}"
|
||||
)
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
text=(
|
||||
"We weren't able to respond because your organization "
|
||||
"has reached its user seat limit. Your account is "
|
||||
"currently deactivated and cannot be reactivated "
|
||||
"until more seats are available. Please contact "
|
||||
"your Onyx administrator."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
activate_user(existing_user, db_session)
|
||||
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"invalidate_license_cache",
|
||||
None,
|
||||
)
|
||||
invalidate_license_cache_fn()
|
||||
logger.info(f"Reactivated inactive Slack user {message_info.email}")
|
||||
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
persona_id: int,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
is_default_request: IsDefaultRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_featured(
|
||||
update_persona_is_default(
|
||||
persona_id=persona_id,
|
||||
featured=is_featured_request.featured,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona featured status")
|
||||
logger.exception("Failed to update persona default status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -107,7 +108,11 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -123,7 +128,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
featured: bool = False
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -150,6 +155,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -167,7 +175,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -206,6 +214,8 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -220,7 +230,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -242,9 +252,11 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -253,6 +265,7 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -276,9 +289,11 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -309,6 +324,7 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -316,10 +332,12 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on persona's internal settings
|
||||
# Model with full context on perona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -342,7 +360,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -373,7 +391,10 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -13,7 +12,13 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -29,6 +34,7 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -49,27 +55,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -125,7 +111,6 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -152,12 +137,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -207,7 +192,6 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -224,6 +208,7 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -239,7 +224,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -252,7 +237,6 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -284,7 +268,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -351,7 +335,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -370,11 +354,13 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
)
|
||||
|
||||
|
||||
@@ -440,7 +426,6 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -473,25 +458,15 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -133,7 +134,11 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -142,7 +147,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -603,9 +603,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -638,7 +638,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
Public providers are always included.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -652,7 +652,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -673,7 +673,7 @@ def list_llm_providers_for_persona(
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -702,7 +702,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects persona restrictions
|
||||
# Use simplified access check - public providers always included
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
|
||||
@@ -198,6 +198,7 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
@@ -152,20 +152,10 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -173,16 +163,11 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -196,8 +181,7 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
],
|
||||
has_more=has_more,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -192,7 +192,6 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -8,3 +8,37 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -528,7 +528,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.4
|
||||
lxml-html-clean==0.4.3
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
@@ -73,6 +74,10 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
user=None, # System persona
|
||||
name=f"Test Persona {uuid.uuid4()}",
|
||||
description="Test persona with no tools for web search test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
"""External dependency unit tests for periodic task claiming.
|
||||
|
||||
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
|
||||
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
|
||||
|
||||
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
|
||||
timestamp for cross-instance dedup. The DB session is released before
|
||||
the task runs, so long-running tasks don't hold connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.periodic_poller import _PeriodicTaskDef
|
||||
from onyx.background.periodic_poller import _try_claim_task
|
||||
from onyx.background.periodic_poller import _try_run_periodic_task
|
||||
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import KVStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
_TEST_LOCK_BASE = 90_000
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def _init_engine() -> None:
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
|
||||
def _make_task(
|
||||
*,
|
||||
name: str | None = None,
|
||||
interval: float = 3600,
|
||||
lock_id: int | None = None,
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name if name is not None else f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
|
||||
run_fn=run_fn if run_fn is not None else MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_kv(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KVStore).filter(
|
||||
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_claim_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimHappyPath:
|
||||
def test_first_claim_succeeds(self) -> None:
|
||||
assert _try_claim_task(_make_task()) is True
|
||||
|
||||
def test_first_claim_creates_kv_row(self) -> None:
|
||||
task = _make_task()
|
||||
_try_claim_task(task)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = (
|
||||
db_session.query(KVStore)
|
||||
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
|
||||
.first()
|
||||
)
|
||||
assert row is not None
|
||||
assert row.value is not None
|
||||
|
||||
def test_second_claim_within_interval_fails(self) -> None:
|
||||
task = _make_task(interval=3600)
|
||||
assert _try_claim_task(task) is True
|
||||
assert _try_claim_task(task) is False
|
||||
|
||||
def test_claim_after_interval_succeeds(self) -> None:
|
||||
task = _make_task(interval=1)
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
assert row is not None
|
||||
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||
db_session.commit()
|
||||
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_run_periodic_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunHappyPath:
|
||||
def test_runs_task_and_updates_last_run_at(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_called_once()
|
||||
assert task.last_run_at > 0
|
||||
|
||||
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn, interval=3600)
|
||||
task.last_run_at = time.monotonic()
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_skips_when_db_claim_blocked(self) -> None:
|
||||
name = f"test-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 10
|
||||
|
||||
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
|
||||
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_task_exception_does_not_propagate(self) -> None:
|
||||
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
def test_claim_committed_before_task_runs(self) -> None:
|
||||
"""The KV claim must be visible in the DB when run_fn executes."""
|
||||
task_name = f"test-order-{uuid4().hex[:8]}"
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
|
||||
claim_visible: list[bool] = []
|
||||
|
||||
def check_claim() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
claim_visible.append(row is not None and row.value is not None)
|
||||
|
||||
task = _PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=_TEST_LOCK_BASE + 11,
|
||||
run_fn=check_claim,
|
||||
)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
assert claim_visible == [True]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrency: only one claimer should win
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimConcurrency:
|
||||
def test_concurrent_claims_single_winner(self) -> None:
|
||||
"""Many threads claim the same task — exactly one should succeed."""
|
||||
num_threads = 20
|
||||
task_name = f"test-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 20
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
results: list[bool] = []
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
results.append(future.result())
|
||||
|
||||
winners = sum(1 for r in results if r)
|
||||
assert winners == 1, f"Expected 1 winner, got {winners}"
|
||||
|
||||
def test_concurrent_run_single_execution(self) -> None:
|
||||
"""Many threads run the same task — run_fn fires exactly once."""
|
||||
num_threads = 20
|
||||
task_name = f"test-run-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 21
|
||||
counter = MagicMock()
|
||||
|
||||
def run() -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
_try_run_periodic_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=counter,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(run) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
assert (
|
||||
counter.call_count == 1
|
||||
), f"Expected run_fn called once, got {counter.call_count}"
|
||||
|
||||
def test_no_errors_under_contention(self) -> None:
    """Thirty simultaneous claim attempts must all finish without raising.

    Every worker calls ``_try_claim_task`` with an identical
    ``_PeriodicTaskDef``. Any exception surfaced by a future is collected,
    and the test passes only if that collection ends up empty.
    """
    num_threads = 30
    task_name = f"test-err-{uuid4().hex[:8]}"
    lock_id = _TEST_LOCK_BASE + 22
    errors: list[Exception] = []

    def claim() -> bool:
        # Set the tenant from inside the worker before claiming.
        CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
        task_def = _PeriodicTaskDef(
            name=task_name,
            interval_seconds=3600,
            lock_id=lock_id,
            run_fn=lambda: None,
        )
        return _try_claim_task(task_def)

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = []
        for _ in range(num_threads):
            pending.append(pool.submit(claim))
        for finished in as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                # Record instead of re-raising so every failure is reported.
                errors.append(e)

    assert errors == [], f"Got {len(errors)} errors: {errors}"
|
||||
@@ -1,219 +0,0 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``*_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
    db_session: Session,
    user_id: object,
    *,
    status: UserFileStatus = UserFileStatus.PROCESSING,
    needs_project_sync: bool = False,
    needs_persona_sync: bool = False,
) -> UserFile:
    """Insert a ``UserFile`` row with random identifiers and return it.

    Args:
        db_session: Session used to add, commit and refresh the row.
        user_id: Value stored in the row's ``user_id`` column.
        status: Initial status; defaults to ``PROCESSING``.
        needs_project_sync: Initial value of the project-sync flag.
        needs_persona_sync: Initial value of the persona-sync flag.

    Returns:
        The committed and refreshed ``UserFile`` instance.
    """
    # Random id / file_id / name so repeated calls do not reuse the same values.
    columns = {
        "id": uuid4(),
        "user_id": user_id,
        "file_id": f"test_file_{uuid4().hex[:8]}",
        "name": f"test_{uuid4().hex[:8]}.txt",
        "file_type": "text/plain",
        "status": status,
        "needs_project_sync": needs_project_sync,
        "needs_persona_sync": needs_persona_sync,
    }
    record = UserFile(**columns)

    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
|
||||
|
||||
|
||||
@pytest.fixture()
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
    """Yield a list for tests to register ``UserFile`` rows; delete them afterwards.

    Tests append every row they create to the yielded list. After the test
    body finishes, each registered row that can still be fetched by id is
    deleted, followed by a single commit.
    """
    tracked: list[UserFile] = []
    yield tracked

    for registered in tracked:
        row = db_session.get(UserFile, registered.id)
        if not row:
            # Nothing fetched for this id, so there is nothing to delete.
            continue
        db_session.delete(row)
    db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
    """Recovery dispatches PROCESSING files to the processing impl, and only those."""

    def test_processing_files_recovered(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """A file seeded as PROCESSING is handed to ``process_user_file_impl``."""
        owner = create_test_user(db_session, "recovery_proc")
        uf = _create_user_file(db_session, owner.id, status=UserFileStatus.PROCESSING)
        _cleanup_user_files.append(uf)

        impl_spy = MagicMock()
        target = f"{_IMPL_MODULE}.process_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        # Collect the user_file_id keyword of every call made to the mock.
        called_ids = []
        for invocation in impl_spy.call_args_list:
            called_ids.append(invocation.kwargs["user_file_id"])

        assert (
            str(uf.id) in called_ids
        ), f"Expected file {uf.id} to be recovered but got: {called_ids}"

    def test_completed_files_not_recovered(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """A file seeded as COMPLETED is never handed to ``process_user_file_impl``."""
        owner = create_test_user(db_session, "recovery_comp")
        uf = _create_user_file(db_session, owner.id, status=UserFileStatus.COMPLETED)
        _cleanup_user_files.append(uf)

        impl_spy = MagicMock()
        target = f"{_IMPL_MODULE}.process_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        called_ids = []
        for invocation in impl_spy.call_args_list:
            called_ids.append(invocation.kwargs["user_file_id"])

        assert (
            str(uf.id) not in called_ids
        ), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
    """Recovery dispatches DELETING files to the delete impl."""

    def test_deleting_files_recovered(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """A file seeded as DELETING is handed to ``delete_user_file_impl``."""
        owner = create_test_user(db_session, "recovery_del")
        uf = _create_user_file(db_session, owner.id, status=UserFileStatus.DELETING)
        _cleanup_user_files.append(uf)

        impl_spy = MagicMock()
        target = f"{_IMPL_MODULE}.delete_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        # Collect the user_file_id keyword of every call made to the mock.
        called_ids = []
        for invocation in impl_spy.call_args_list:
            called_ids.append(invocation.kwargs["user_file_id"])

        assert (
            str(uf.id) in called_ids
        ), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
    """Recovery dispatches files flagged for project or persona sync to the sync impl."""

    def test_needs_project_sync_recovered(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """A COMPLETED file with ``needs_project_sync`` set reaches the sync impl."""
        owner = create_test_user(db_session, "recovery_sync")
        uf = _create_user_file(
            db_session,
            owner.id,
            status=UserFileStatus.COMPLETED,
            needs_project_sync=True,
        )
        _cleanup_user_files.append(uf)

        impl_spy = MagicMock()
        target = f"{_IMPL_MODULE}.project_sync_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        # Collect the user_file_id keyword of every call made to the mock.
        called_ids = []
        for invocation in impl_spy.call_args_list:
            called_ids.append(invocation.kwargs["user_file_id"])

        assert (
            str(uf.id) in called_ids
        ), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"

    def test_needs_persona_sync_recovered(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """A COMPLETED file with ``needs_persona_sync`` set reaches the sync impl."""
        owner = create_test_user(db_session, "recovery_psync")
        uf = _create_user_file(
            db_session,
            owner.id,
            status=UserFileStatus.COMPLETED,
            needs_persona_sync=True,
        )
        _cleanup_user_files.append(uf)

        impl_spy = MagicMock()
        # The same project_sync_user_file_impl target is patched for the persona flag.
        target = f"{_IMPL_MODULE}.project_sync_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        called_ids = []
        for invocation in impl_spy.call_args_list:
            called_ids.append(invocation.kwargs["user_file_id"])

        assert (
            str(uf.id) in called_ids
        ), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
    """A single recovery call must cover every stuck file, not only one."""

    def test_multiple_processing_files(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        """Three PROCESSING files are all handed to ``process_user_file_impl``."""
        owner = create_test_user(db_session, "recovery_multi")

        files = []
        for _ in range(3):
            seeded = _create_user_file(
                db_session, owner.id, status=UserFileStatus.PROCESSING
            )
            # Register for cleanup right away, before creating the next row.
            _cleanup_user_files.append(seeded)
            files.append(seeded)

        impl_spy = MagicMock()
        target = f"{_IMPL_MODULE}.process_user_file_impl"
        with patch(target, impl_spy):
            recover_stuck_user_files(TEST_TENANT_ID)

        called_ids = set()
        for invocation in impl_spy.call_args_list:
            called_ids.add(invocation.kwargs["user_file_id"])
        expected_ids = set(str(seeded.id) for seeded in files)

        # Subset check: ids other than the three seeded here are tolerated.
        assert expected_ids <= called_ids, (
            f"Expected all {len(files)} files to be recovered. "
            f"Missing: {expected_ids - called_ids}"
        )
|
||||
@@ -36,6 +36,7 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -85,6 +86,12 @@ def _create_test_persona(
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
@@ -403,6 +410,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -431,6 +442,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -446,11 +461,16 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -481,6 +501,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -495,10 +519,15 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -57,6 +58,12 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -54,6 +55,11 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
|
||||
persona = Persona(
|
||||
name=f"test_slack_persona_{unique_id}",
|
||||
description="Test persona for Slack federated search",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question based on the provided context.",
|
||||
)
|
||||
@@ -812,6 +818,11 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
|
||||
persona = Persona(
|
||||
name=f"test_eager_load_persona_{unique_id}",
|
||||
description="Test persona for eager loading test",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question.",
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -46,6 +47,12 @@ def _create_test_persona_with_mcp_tool(
|
||||
persona = Persona(
|
||||
name=f"Test MCP Persona {uuid4().hex[:8]}",
|
||||
description="Test persona with MCP tools",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -17,6 +17,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -56,6 +57,12 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -933,7 +933,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from fastapi.background import BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -1140,7 +1139,6 @@ def test_code_interpreter_receives_chat_files(
|
||||
# Upload a test CSV
|
||||
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
|
||||
result = upload_user_files(
|
||||
bg_tasks=BackgroundTasks(),
|
||||
files=[
|
||||
UploadFile(
|
||||
file=io.BytesIO(csv_content),
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -19,7 +20,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -30,7 +35,6 @@ class PersonaManager:
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
display_priority: int | None = None,
|
||||
featured: bool = False,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
@@ -43,7 +47,11 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -53,7 +61,6 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -68,7 +75,11 @@ class PersonaManager:
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
@@ -79,7 +90,6 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -90,7 +100,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -99,7 +113,6 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
featured: bool | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -110,7 +123,13 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=persona.is_public if is_public is None else is_public,
|
||||
llm_filter_extraction=(
|
||||
llm_filter_extraction or persona.llm_filter_extraction
|
||||
),
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
@@ -122,7 +141,6 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -137,12 +155,16 @@ class PersonaManager:
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
|
||||
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
@@ -151,8 +173,7 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -201,13 +222,32 @@ class PersonaManager:
|
||||
fetched_persona.description,
|
||||
)
|
||||
)
|
||||
if fetched_persona.num_chunks != persona.num_chunks:
|
||||
mismatches.append(
|
||||
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
|
||||
)
|
||||
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_relevance_filter",
|
||||
persona.llm_relevance_filter,
|
||||
fetched_persona.llm_relevance_filter,
|
||||
)
|
||||
)
|
||||
if fetched_persona.is_public != persona.is_public:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if fetched_persona.featured != persona.featured:
|
||||
if (
|
||||
fetched_persona.llm_filter_extraction
|
||||
!= persona.llm_filter_extraction
|
||||
):
|
||||
mismatches.append(
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
(
|
||||
"llm_filter_extraction",
|
||||
persona.llm_filter_extraction,
|
||||
fetched_persona.llm_filter_extraction,
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
|
||||
|
||||
class ScimClient:
    """Thin static wrapper around ``requests`` for the ``/scim/v2`` API.

    Every call targets ``{API_SERVER_URL}/scim/v2{path}`` with a 60 second
    timeout. All methods except ``get_no_auth`` send a bearer token built
    from ``raw_token``; each returns the ``requests.Response`` unchanged.
    """

    @staticmethod
    def _headers(raw_token: str) -> dict[str, str]:
        """Return ``GENERAL_HEADERS`` plus a bearer ``Authorization`` header."""
        merged = dict(GENERAL_HEADERS)
        merged["Authorization"] = f"Bearer {raw_token}"
        return merged

    @staticmethod
    def get(path: str, raw_token: str) -> requests.Response:
        """Send an authenticated GET to the SCIM ``path``."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        auth_headers = ScimClient._headers(raw_token)
        return requests.get(endpoint, headers=auth_headers, timeout=60)

    @staticmethod
    def post(path: str, raw_token: str, json: dict) -> requests.Response:
        """Send an authenticated POST with ``json`` as the request body."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        auth_headers = ScimClient._headers(raw_token)
        return requests.post(endpoint, json=json, headers=auth_headers, timeout=60)

    @staticmethod
    def put(path: str, raw_token: str, json: dict) -> requests.Response:
        """Send an authenticated PUT with ``json`` as the request body."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        auth_headers = ScimClient._headers(raw_token)
        return requests.put(endpoint, json=json, headers=auth_headers, timeout=60)

    @staticmethod
    def patch(path: str, raw_token: str, json: dict) -> requests.Response:
        """Send an authenticated PATCH with ``json`` as the request body."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        auth_headers = ScimClient._headers(raw_token)
        return requests.patch(endpoint, json=json, headers=auth_headers, timeout=60)

    @staticmethod
    def delete(path: str, raw_token: str) -> requests.Response:
        """Send an authenticated DELETE to the SCIM ``path``."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        auth_headers = ScimClient._headers(raw_token)
        return requests.delete(endpoint, headers=auth_headers, timeout=60)

    @staticmethod
    def get_no_auth(path: str) -> requests.Response:
        """Send a GET with only ``GENERAL_HEADERS`` (no ``Authorization`` added)."""
        endpoint = f"{API_SERVER_URL}/scim/v2{path}"
        return requests.get(endpoint, headers=GENERAL_HEADERS, timeout=60)
|
||||
@@ -1,6 +1,7 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -50,3 +51,29 @@ class ScimTokenManager:
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import Field
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -161,7 +162,11 @@ class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
@@ -169,7 +174,6 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -8,6 +8,7 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
@@ -35,8 +36,14 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -414,24 +414,6 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Pause the connector immediately to prevent check_for_indexing from
|
||||
# creating automatic retry attempts while we reset the mock server.
|
||||
# Without this, the INITIAL_INDEXING status causes immediate retries
|
||||
# that would consume (or fail against) the mock server before we can
|
||||
# set up the recovery behavior.
|
||||
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
|
||||
# Collect all index attempt IDs created so far (the initial one plus
|
||||
# any automatic retries that may have started before the pause took effect).
|
||||
all_prior_attempt_ids: list[int] = []
|
||||
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair.id,
|
||||
page=0,
|
||||
page_size=100,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
@@ -483,14 +465,17 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Set the manual indexing trigger, then unpause to allow the recovery run.
|
||||
# After the failure, the connector is in repeated error state and paused.
|
||||
# Set the manual indexing trigger first (while paused), then unpause.
|
||||
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
|
||||
# prevent the connector from being re-paused when repeated error state is detected.
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=all_prior_attempt_ids,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
|
||||
@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
|
||||
# )
|
||||
break
|
||||
|
||||
if time.monotonic() - start_time > 90:
|
||||
assert False, "CC pair did not enter repeated error state within 90 seconds"
|
||||
if time.monotonic() - start_time > 30:
|
||||
assert False, "CC pair did not enter repeated error state within 30 seconds"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
@@ -77,6 +78,12 @@ def _create_persona(
|
||||
persona = Persona(
|
||||
name=name,
|
||||
description=f"{name} description",
|
||||
num_chunks=5,
|
||||
chunks_above=2,
|
||||
chunks_below=2,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=provider_name,
|
||||
llm_model_version_override="gpt-4o-mini",
|
||||
system_prompt="System prompt",
|
||||
@@ -243,116 +250,6 @@ def test_can_user_access_llm_provider_or_logic(
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_with_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers should still enforce persona restrictions.
|
||||
|
||||
Regression test for the bug where is_public=True caused
|
||||
can_user_access_llm_provider() to return True immediately,
|
||||
bypassing persona whitelist checks entirely.
|
||||
"""
|
||||
admin_user, _basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Public provider with persona restrictions
|
||||
public_restricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-persona-restricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
non_whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="non-whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
|
||||
# Only whitelist one persona
|
||||
db_session.add(
|
||||
LLMProvider__Persona(
|
||||
llm_provider_id=public_restricted.id,
|
||||
persona_id=whitelisted_persona.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
db_session.refresh(public_restricted)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
assert admin_model is not None
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
|
||||
# Whitelisted persona — should be allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
whitelisted_persona,
|
||||
)
|
||||
|
||||
# Non-whitelisted persona — should be denied despite is_public=True
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
non_whitelisted_persona,
|
||||
)
|
||||
|
||||
# No persona context (e.g. global provider list) — should be denied
|
||||
# because provider has persona restrictions set
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
persona=None,
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_without_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers with no persona restrictions remain accessible to all."""
|
||||
admin_user, basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
public_unrestricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-unrestricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
any_persona = _create_persona(
|
||||
db_session,
|
||||
name="any-persona",
|
||||
provider_name=public_unrestricted.name,
|
||||
)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
basic_model = db_session.get(User, basic_user.id)
|
||||
assert admin_model is not None
|
||||
assert basic_model is not None
|
||||
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
|
||||
|
||||
# Any user, any persona — all allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, basic_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, persona=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, builtin_persona, featured, deleted
|
||||
SELECT id, name, builtin_persona, is_default_persona, deleted
|
||||
FROM persona
|
||||
WHERE builtin_persona = true
|
||||
ORDER BY id
|
||||
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert default[0] == 0, "Default assistant should have ID 0"
|
||||
assert default[1] == "Assistant", "Should be named 'Assistant'"
|
||||
assert default[2] is True, "Should be builtin"
|
||||
assert default[3] is True, "Should be featured"
|
||||
assert default[3] is True, "Should be default"
|
||||
assert default[4] is False, "Should not be deleted"
|
||||
|
||||
# Check tools are properly associated
|
||||
|
||||
@@ -195,7 +195,11 @@ def _base_persona_body(**overrides: object) -> dict:
|
||||
"description": "test",
|
||||
"system_prompt": "test",
|
||||
"task_prompt": "",
|
||||
"num_chunks": 5,
|
||||
"is_public": True,
|
||||
"recency_bias": "auto",
|
||||
"llm_filter_extraction": False,
|
||||
"llm_relevance_filter": False,
|
||||
"datetime_aware": False,
|
||||
"document_set_ids": [],
|
||||
"tool_ids": [],
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
|
||||
|
||||
Covers: upload → COMPLETED → unlink from project → delete → gone.
|
||||
|
||||
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
|
||||
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
|
||||
when the server is running with vector DB enabled.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
POLL_INTERVAL_SECONDS = 1
|
||||
POLL_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
def _poll_file_status(
|
||||
file_id: UUID,
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus,
|
||||
timeout: int = POLL_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
"""Poll GET /user/projects/file/{file_id} until the file reaches *target_status*."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.ok:
|
||||
status = resp.json().get("status")
|
||||
if status == target_status.value:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} did not reach {target_status.value} within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
|
||||
"""Poll until GET /user/projects/file/{file_id} returns 404."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} still accessible after {timeout}s (expected 404)"
|
||||
)
|
||||
|
||||
|
||||
def test_file_upload_process_delete_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Full lifecycle: upload → COMPLETED → unlink → delete → 404.
|
||||
|
||||
Validates that the API server handles all background processing
|
||||
(via FastAPI BackgroundTasks) without any Celery workers running.
|
||||
"""
|
||||
project = ProjectManager.create(
|
||||
name="lifecycle-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
file_content = b"Integration test file content for lifecycle verification."
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("lifecycle.txt", file_content)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert upload_result.user_files, "Expected at least one file in upload response"
|
||||
|
||||
user_file = upload_result.user_files[0]
|
||||
file_id = user_file.id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
project_files = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert any(
|
||||
f.id == file_id for f in project_files
|
||||
), "File should be listed in project files after processing"
|
||||
|
||||
# Unlink the file from the project so the delete endpoint will proceed
|
||||
unlink_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
unlink_resp.status_code == 204
|
||||
), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
|
||||
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
delete_resp.ok
|
||||
), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
|
||||
body = delete_resp.json()
|
||||
assert (
|
||||
body["has_associations"] is False
|
||||
), f"File still has associations after unlink: {body}"
|
||||
|
||||
_file_is_gone(file_id, admin_user)
|
||||
|
||||
project_files_after = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert not any(
|
||||
f.id == file_id for f in project_files_after
|
||||
), "Deleted file should not appear in project files"
|
||||
|
||||
|
||||
def test_delete_blocked_while_associated(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a file that still belongs to a project should return
|
||||
has_associations=True without actually deleting the file."""
|
||||
project = ProjectManager.create(
|
||||
name="assoc-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("assoc.txt", b"associated file content")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
file_id = upload_result.user_files[0].id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
# Attempt to delete while still linked
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_resp.ok
|
||||
body = delete_resp.json()
|
||||
assert body["has_associations"] is True, "Should report existing associations"
|
||||
assert project.name in body["project_names"]
|
||||
|
||||
# File should still be accessible
|
||||
get_resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert get_resp.status_code == 200, "File should still exist after blocked delete"
|
||||
@@ -40,6 +40,7 @@ def test_persona_create_update_share_delete(
|
||||
expected_persona,
|
||||
name=f"updated-{expected_persona.name}",
|
||||
description=f"updated-{expected_persona.description}",
|
||||
num_chunks=expected_persona.num_chunks + 1,
|
||||
is_public=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,11 @@ def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
|
||||
@@ -31,8 +31,9 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None: # noqa
|
||||
"search, web browsing, and image generation"
|
||||
in unified_assistant.description.lower()
|
||||
)
|
||||
assert unified_assistant.featured is True
|
||||
assert unified_assistant.is_default_persona is True
|
||||
assert unified_assistant.is_visible is True
|
||||
assert unified_assistant.num_chunks == 25
|
||||
|
||||
# Verify tools
|
||||
tools = unified_assistant.tools
|
||||
|
||||
@@ -1,552 +0,0 @@
|
||||
"""Integration tests for SCIM group provisioning endpoints.
|
||||
|
||||
Covers the full group lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a group via POST /Groups
|
||||
2. Retrieve a group via GET /Groups/{id}
|
||||
3. List, filter, and paginate groups via GET /Groups
|
||||
4. Replace a group via PUT /Groups/{id}
|
||||
5. Patch a group (add/remove members, rename) via PATCH /Groups/{id}
|
||||
6. Delete a group via DELETE /Groups/{id}
|
||||
7. Error cases: duplicate name, not-found, invalid member IDs
|
||||
|
||||
All tests are parameterized across IdP request styles (Okta sends lowercase
|
||||
PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server
|
||||
normalizes both — these tests verify that.
|
||||
|
||||
Auth tests live in test_scim_tokens.py.
|
||||
User lifecycle tests live in test_scim_users.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
|
||||
def idp_style(request: pytest.FixtureRequest) -> str:
|
||||
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scim_token(idp_style: str) -> str:
|
||||
"""Create a single SCIM token shared across all tests in this module.
|
||||
|
||||
Creating a new token revokes the previous one, so we create exactly once
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
token = ScimTokenManager.create(
|
||||
name=f"scim-group-tests-{idp_style}",
|
||||
user_performing_action=admin,
|
||||
).raw_token
|
||||
assert token is not None
|
||||
return token
|
||||
|
||||
|
||||
def _make_group_resource(
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal SCIM GroupResource payload."""
|
||||
resource: dict = {
|
||||
"schemas": [SCIM_GROUP_SCHEMA],
|
||||
"displayName": display_name,
|
||||
}
|
||||
if external_id is not None:
|
||||
resource["externalId"] = external_id
|
||||
if members is not None:
|
||||
resource["members"] = members
|
||||
return resource
|
||||
|
||||
|
||||
def _make_user_resource(email: str, external_id: str) -> dict:
|
||||
"""Build a minimal SCIM UserResource payload for member creation."""
|
||||
return {
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"externalId": external_id,
|
||||
"name": {"givenName": "Test", "familyName": "User"},
|
||||
"active": True,
|
||||
}
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
|
||||
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
|
||||
|
||||
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
|
||||
``"replace"``). The server's ``normalize_operation`` validator lowercases
|
||||
them — these tests verify that both casings are accepted.
|
||||
"""
|
||||
cased_operations = []
|
||||
for operation in operations:
|
||||
cased = dict(operation)
|
||||
if idp_style == "entra":
|
||||
cased["op"] = operation["op"].capitalize()
|
||||
cased_operations.append(cased)
|
||||
return {
|
||||
"schemas": [SCIM_PATCH_SCHEMA],
|
||||
"Operations": cased_operations,
|
||||
}
|
||||
|
||||
|
||||
def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Users", token, json=_make_user_resource(email, external_id)
|
||||
)
|
||||
|
||||
|
||||
def _create_scim_group(
|
||||
token: str,
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Groups",
|
||||
token,
|
||||
json=_make_group_resource(display_name, external_id, members),
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create → get → list → replace → patch → delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups creates a group and returns 201."""
|
||||
name = f"Engineering {idp_style}"
|
||||
resp = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}")
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == name
|
||||
assert body["externalId"] == f"ext-eng-{idp_style}"
|
||||
assert body["id"] # integer ID assigned by server
|
||||
assert body["meta"]["resourceType"] == "Group"
|
||||
|
||||
|
||||
def test_create_group_with_members(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with members populates the member list."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Backend Team {idp_style}",
|
||||
external_id=f"ext-backend-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_get_group(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups/{id} returns the group resource including members."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Frontend Team {idp_style}",
|
||||
external_id=f"ext-fe-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["id"] == created["id"]
|
||||
assert body["displayName"] == f"Frontend Team {idp_style}"
|
||||
assert body["externalId"] == f"ext-fe-{idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_list_groups(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups returns a ListResponse containing provisioned groups."""
|
||||
name = f"DevOps Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}")
|
||||
|
||||
resp = ScimClient.get("/Groups", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] >= 1
|
||||
names = [r["displayName"] for r in body["Resources"]]
|
||||
assert name in names
|
||||
|
||||
|
||||
def test_list_groups_pagination(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups with startIndex and count returns correct pagination."""
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group A {idp_style}", external_id=f"ext-page-a-{idp_style}"
|
||||
)
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group B {idp_style}", external_id=f"ext-page-b-{idp_style}"
|
||||
)
|
||||
|
||||
resp = ScimClient.get("/Groups?startIndex=1&count=1", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["startIndex"] == 1
|
||||
assert body["itemsPerPage"] == 1
|
||||
assert body["totalResults"] >= 2
|
||||
assert len(body["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=displayName eq '...' returns only matching groups."""
|
||||
name = f"Unique QA Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}")
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["displayName"] == name
|
||||
|
||||
|
||||
def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=externalId eq '...' returns the matching group."""
|
||||
ext_id = f"ext-unique-group-id-{idp_style}"
|
||||
_create_scim_group(
|
||||
scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id
|
||||
)
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["externalId"] == ext_id
|
||||
|
||||
|
||||
def test_replace_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} replaces the group resource."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Original Name {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}"
|
||||
).json()
|
||||
|
||||
updated_resource = _make_group_resource(
|
||||
display_name=f"Renamed Group {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == f"Renamed Group {idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} with empty members removes all memberships."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_clear_m_{idp_style}@example.com", f"ext-gcm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Clear Members Group {idp_style}",
|
||||
external_id=f"ext-clear-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[]
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_add_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=add adds a member."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Add Group {idp_style}",
|
||||
external_id=f"ext-patch-add-{idp_style}",
|
||||
).json()
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_patch_remove_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=remove removes a specific member."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Remove Group {idp_style}",
|
||||
external_id=f"ext-patch-rm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "remove",
|
||||
"path": f'members[value eq "{user["id"]}"]',
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_replace_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on members swaps the entire list."""
|
||||
user_a = _create_scim_user(
|
||||
scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}"
|
||||
).json()
|
||||
user_b = _create_scim_user(
|
||||
scim_token, f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Replace Group {idp_style}",
|
||||
external_id=f"ext-patch-repl-{idp_style}",
|
||||
members=[{"value": user_a["id"]}],
|
||||
).json()
|
||||
|
||||
# Replace member list: swap A for B
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "members",
|
||||
"value": [{"value": user_b["id"]}],
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user_b["id"] in member_ids
|
||||
assert user_a["id"] not in member_ids
|
||||
|
||||
|
||||
def test_patch_rename_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on displayName renames the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Old Group Name {idp_style}",
|
||||
external_id=f"ext-rename-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
new_name = f"New Group Name {idp_style}"
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": new_name}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["displayName"] == new_name
|
||||
|
||||
# Confirm via GET
|
||||
get_resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert get_resp.json()["displayName"] == new_name
|
||||
|
||||
|
||||
def test_delete_group(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete Me Group {idp_style}",
|
||||
external_id=f"ext-del-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Second DELETE returns 404 (group hard-deleted)
|
||||
resp2 = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes memberships but does not deactivate users."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete With Members {idp_style}",
|
||||
external_id=f"ext-del-wm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# User should still be active and retrievable
|
||||
user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token)
|
||||
assert user_resp.status_code == 200
|
||||
assert user_resp.json()["active"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with an already-taken displayName returns 409."""
|
||||
name = f"Dup Name Group {idp_style}"
|
||||
resp1 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}")
|
||||
assert resp1.status_code == 201
|
||||
|
||||
resp2 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}")
|
||||
assert resp2.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_group(scim_token: str) -> None:
|
||||
"""GET /Groups/{bad-id} returns 404."""
|
||||
resp = ScimClient.get("/Groups/999999999", scim_token)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with a non-existent member UUID returns 400."""
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Bad Member Group {idp_style}",
|
||||
external_id=f"ext-bad-m-{idp_style}",
|
||||
members=[{"value": "00000000-0000-0000-0000-000000000000"}],
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "not found" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} that adds an unknown member UUID is rejected with 400."""
    group = _create_scim_group(
        scim_token,
        f"Patch Bad Member Group {idp_style}",
        external_id=f"ext-pbm-{idp_style}",
    ).json()
    group_path = f"/Groups/{group['id']}"

    # Single "add members" operation pointing at a UUID no user has.
    add_unknown_member = {
        "op": "add",
        "path": "members",
        "value": [{"value": "00000000-0000-0000-0000-000000000000"}],
    }
    patch_body = _make_patch_request([add_unknown_member], idp_style)

    response = ScimClient.patch(group_path, scim_token, json=patch_body)
    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "not found" in detail.lower()
|
||||
|
||||
|
||||
def test_patch_add_duplicate_member_is_idempotent(
    scim_token: str, idp_style: str
) -> None:
    """Re-adding a member that is already in the group succeeds and adds nothing."""
    member = _create_scim_user(
        scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}"
    ).json()
    member_ref = {"value": member["id"]}

    group = _create_scim_group(
        scim_token,
        f"Idempotent Add Group {idp_style}",
        external_id=f"ext-idem-g-{idp_style}",
        members=[member_ref],
    ).json()
    assert len(group["members"]) == 1
    group_path = f"/Groups/{group['id']}"

    # Issue an "add" for the very same member a second time.
    readd_operation = {
        "op": "add",
        "path": "members",
        "value": [{"value": member["id"]}],
    }
    response = ScimClient.patch(
        group_path,
        scim_token,
        json=_make_patch_request([readd_operation], idp_style),
    )

    assert response.status_code == 200
    # Membership must not have been duplicated.
    assert len(response.json()["members"]) == 1
|
||||
@@ -15,7 +15,6 @@ import time
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -40,7 +39,7 @@ def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimClient.get("/Users", token.raw_token)
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
@@ -55,7 +54,7 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimClient.get("/Users", first.raw_token)
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
@@ -70,22 +69,25 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimClient.get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimClient.get("/Users", second.raw_token).status_code == 200
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimClient.get_no_auth("/Users").status_code == 401
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
@@ -137,7 +139,7 @@ def test_service_discovery_no_auth_required(
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimClient.get_no_auth(path)
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
@@ -156,7 +158,7 @@ def test_last_used_at_updated_after_scim_request(
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimClient.get("/Users", token.raw_token).status_code == 200
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
|
||||
@@ -1,517 +0,0 @@
|
||||
"""Integration tests for SCIM user provisioning endpoints.
|
||||
|
||||
Covers the full user lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a user via POST /Users
|
||||
2. Retrieve a user via GET /Users/{id}
|
||||
3. List, filter, and paginate users via GET /Users
|
||||
4. Replace a user via PUT /Users/{id}
|
||||
5. Patch a user (deactivate/reactivate) via PATCH /Users/{id}
|
||||
6. Delete a user via DELETE /Users/{id}
|
||||
7. Error cases: missing externalId, duplicate email, not-found, seat limit
|
||||
|
||||
All tests are parameterized across IdP request styles:
|
||||
- **Okta**: lowercase PATCH ops, minimal payloads (core schema only).
|
||||
- **Entra**: capitalized ops (``"Replace"``), enterprise extension data
|
||||
(department, manager), and structured email arrays.
|
||||
|
||||
The server normalizes both — these tests verify that all IdP-specific fields
|
||||
are accepted and round-tripped correctly.
|
||||
|
||||
Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
def idp_style(request: pytest.FixtureRequest) -> str:
    """Module-scoped parameter: yields "okta" and then "entra" so tests run once per IdP style."""
    style: str = request.param
    return style
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def scim_token(idp_style: str) -> str:
    """Return one raw SCIM token that every test in the module shares.

    A newly created token revokes the one before it, so the token is minted
    exactly once per IdP-style run. The admin is obtained through UserManager
    rather than the function-scoped ``admin_user`` fixture to avoid a
    fixture-scope conflict with this module-scoped fixture.
    """
    from tests.integration.common_utils.constants import ADMIN_USER_NAME
    from tests.integration.common_utils.constants import GENERAL_HEADERS
    from tests.integration.common_utils.managers.user import build_email
    from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
    from tests.integration.common_utils.managers.user import UserManager
    from tests.integration.common_utils.test_models import DATestUser

    try:
        admin = UserManager.create(name=ADMIN_USER_NAME)
    except Exception:
        # Creation failed (presumably the admin already exists from an earlier
        # parameter run); fall back to logging in with the default credentials.
        existing_admin = DATestUser(
            id="",
            email=build_email(ADMIN_USER_NAME),
            password=DEFAULT_PASSWORD,
            headers=GENERAL_HEADERS,
            role=UserRole.ADMIN,
            is_active=True,
        )
        admin = UserManager.login_as_user(existing_admin)

    created_token = ScimTokenManager.create(
        name=f"scim-user-tests-{idp_style}",
        user_performing_action=admin,
    )
    raw_token = created_token.raw_token
    assert raw_token is not None
    return raw_token
|
||||
|
||||
|
||||
def _make_user_resource(
    email: str,
    external_id: str,
    given_name: str = "Test",
    family_name: str = "User",
    active: bool = True,
    idp_style: str = "okta",
    department: str | None = None,
    manager_id: str | None = None,
) -> dict:
    """Assemble a SCIM user payload shaped like the given IdP would send it.

    Every style gets the core fields (schemas, userName, externalId, name,
    active). For ``idp_style == "entra"`` the payload additionally carries the
    enterprise schema URN, an enterprise extension object (department and
    manager, defaulting to "Engineering" / "mgr-ext-001" when not supplied or
    empty), and a structured ``emails`` array with one primary work address.
    """
    schemas = [SCIM_USER_SCHEMA]
    resource: dict = {
        "schemas": schemas,
        "userName": email,
        "externalId": external_id,
        "name": {"givenName": given_name, "familyName": family_name},
        "active": active,
    }

    # Non-Entra styles send only the core payload.
    if idp_style != "entra":
        return resource

    schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
    resource[SCIM_ENTERPRISE_USER_SCHEMA] = {
        "department": department or "Engineering",
        "manager": {"value": manager_id or "mgr-ext-001"},
    }
    resource["emails"] = [{"value": email, "type": "work", "primary": True}]
    return resource
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
    """Wrap operations in a SCIM PatchOp envelope using the IdP's ``op`` casing.

    For ``idp_style == "entra"`` each operation's ``op`` value is capitalized
    (``"replace"`` -> ``"Replace"``); other styles keep it as given. The input
    operation dicts are shallow-copied, never mutated.
    """
    capitalize_ops = idp_style == "entra"

    def _with_casing(operation: dict) -> dict:
        # Copy first so the caller's dict is left untouched.
        copied = dict(operation)
        if capitalize_ops:
            copied["op"] = operation["op"].capitalize()
        return copied

    return {
        "schemas": [SCIM_PATCH_SCHEMA],
        "Operations": [_with_casing(operation) for operation in operations],
    }
|
||||
|
||||
|
||||
def _create_scim_user(
    token: str,
    email: str,
    external_id: str,
    idp_style: str = "okta",
) -> requests.Response:
    """POST a user resource built for ``idp_style`` to /Users and return the raw response."""
    payload = _make_user_resource(email, external_id, idp_style=idp_style)
    return ScimClient.post("/Users", token, json=payload)
|
||||
|
||||
|
||||
def _assert_entra_extension(
    body: dict,
    expected_department: str = "Engineering",
    expected_manager: str = "mgr-ext-001",
) -> None:
    """Check that a response body carries the enterprise extension with the expected values."""
    # The enterprise URN must be advertised in the schemas list...
    assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"]

    # ...and the extension object keyed by that URN must match.
    extension = body[SCIM_ENTERPRISE_USER_SCHEMA]
    department = extension["department"]
    assert department == expected_department
    manager_value = extension["manager"]["value"]
    assert manager_value == expected_manager
|
||||
|
||||
|
||||
def _assert_entra_emails(body: dict, expected_email: str) -> None:
|
||||
"""Assert that structured email metadata round-tripped correctly."""
|
||||
emails = body["emails"]
|
||||
assert len(emails) >= 1
|
||||
work_email = next(e for e in emails if e.get("type") == "work")
|
||||
assert work_email["value"] == expected_email
|
||||
assert work_email["primary"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create -> get -> list -> replace -> patch -> delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user(scim_token: str, idp_style: str) -> None:
    """POST /Users provisions a user, answers 201, and echoes the submitted fields."""
    email = f"scim_create_{idp_style}@example.com"
    ext_id = f"ext-create-{idp_style}"

    response = _create_scim_user(scim_token, email, ext_id, idp_style)
    assert response.status_code == 201

    created = response.json()
    assert created["userName"] == email
    assert created["externalId"] == ext_id
    assert created["active"] is True
    assert created["id"]  # server-assigned identifier must be non-empty
    assert created["meta"]["resourceType"] == "User"
    # Name parts are the defaults used by the payload builder.
    assert created["name"]["givenName"] == "Test"
    assert created["name"]["familyName"] == "User"

    # Entra payloads carry extra data that must round-trip as well.
    if idp_style == "entra":
        _assert_entra_extension(created)
        _assert_entra_emails(created, email)
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
    """GET /Users/{id} returns the stored fields of a previously created user."""
    email = f"scim_get_{idp_style}@example.com"
    ext_id = f"ext-get-{idp_style}"
    created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
    user_path = f"/Users/{created['id']}"

    response = ScimClient.get(user_path, scim_token)
    assert response.status_code == 200

    fetched = response.json()
    assert fetched["id"] == created["id"]
    assert fetched["userName"] == email
    assert fetched["externalId"] == ext_id
    assert fetched["name"]["givenName"] == "Test"
    assert fetched["name"]["familyName"] == "User"

    if idp_style == "entra":
        _assert_entra_extension(fetched)
        _assert_entra_emails(fetched, email)
|
||||
|
||||
|
||||
def test_list_users(scim_token: str, idp_style: str) -> None:
    """GET /Users returns a list response that includes a freshly provisioned user."""
    email = f"scim_list_{idp_style}@example.com"
    _create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style)

    response = ScimClient.get("/Users", scim_token)
    assert response.status_code == 200

    listing = response.json()
    assert listing["totalResults"] >= 1

    # Collect every returned userName and look for the one just created.
    listed_usernames = []
    for resource in listing["Resources"]:
        listed_usernames.append(resource["userName"])
    assert email in listed_usernames
|
||||
|
||||
|
||||
def test_list_users_pagination(scim_token: str, idp_style: str) -> None:
    """GET /Users honours startIndex and count and reports them in the response."""
    # Provision two users so a one-item page cannot cover the whole result set.
    _create_scim_user(
        token=scim_token,
        email=f"scim_page1_{idp_style}@example.com",
        external_id=f"ext-page-1-{idp_style}",
        idp_style=idp_style,
    )
    _create_scim_user(
        token=scim_token,
        email=f"scim_page2_{idp_style}@example.com",
        external_id=f"ext-page-2-{idp_style}",
        idp_style=idp_style,
    )

    response = ScimClient.get("/Users?startIndex=1&count=1", scim_token)
    assert response.status_code == 200

    page = response.json()
    assert page["startIndex"] == 1
    assert page["itemsPerPage"] == 1
    assert page["totalResults"] >= 2
    assert len(page["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_users_by_username(scim_token: str, idp_style: str) -> None:
    """GET /Users filtered on userName returns exactly the matching user."""
    email = f"scim_filter_{idp_style}@example.com"
    _create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style)

    filtered_path = f'/Users?filter=userName eq "{email}"'
    response = ScimClient.get(filtered_path, scim_token)
    assert response.status_code == 200

    result = response.json()
    assert result["totalResults"] == 1
    only_match = result["Resources"][0]
    assert only_match["userName"] == email
|
||||
|
||||
|
||||
def test_replace_user(scim_token: str, idp_style: str) -> None:
    """PUT /Users/{id} overwrites the user's name and, for Entra, enterprise fields."""
    email = f"scim_replace_{idp_style}@example.com"
    ext_id = f"ext-replace-{idp_style}"
    created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
    user_path = f"/Users/{created['id']}"

    # Same identity, new name parts and a different department.
    replacement = _make_user_resource(
        email,
        ext_id,
        given_name="Updated",
        family_name="Name",
        idp_style=idp_style,
        department="Product",
    )
    response = ScimClient.put(user_path, scim_token, json=replacement)
    assert response.status_code == 200

    replaced = response.json()
    assert replaced["name"]["givenName"] == "Updated"
    assert replaced["name"]["familyName"] == "Name"

    if idp_style == "entra":
        _assert_entra_extension(replaced, expected_department="Product")
        _assert_entra_emails(replaced, email)
|
||||
|
||||
|
||||
def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None:
    """PATCH /Users/{id} replacing active with false deactivates the user."""
    created = _create_scim_user(
        scim_token,
        f"scim_deactivate_{idp_style}@example.com",
        f"ext-deactivate-{idp_style}",
        idp_style,
    ).json()
    assert created["active"] is True
    user_path = f"/Users/{created['id']}"

    deactivate_body = _make_patch_request(
        [{"op": "replace", "path": "active", "value": False}], idp_style
    )
    patch_response = ScimClient.patch(user_path, scim_token, json=deactivate_body)
    assert patch_response.status_code == 200
    assert patch_response.json()["active"] is False

    # A follow-up read must report the same state.
    reread = ScimClient.get(f"/Users/{created['id']}", scim_token)
    assert reread.json()["active"] is False
|
||||
|
||||
|
||||
def test_patch_reactivate_user(scim_token: str, idp_style: str) -> None:
    """PATCH with active=true brings a previously deactivated user back."""
    created = _create_scim_user(
        scim_token,
        f"scim_reactivate_{idp_style}@example.com",
        f"ext-reactivate-{idp_style}",
        idp_style,
    ).json()

    def _patch_active(value: bool) -> requests.Response:
        # One "replace active" operation, cased for the current IdP style.
        return ScimClient.patch(
            f"/Users/{created['id']}",
            scim_token,
            json=_make_patch_request(
                [{"op": "replace", "path": "active", "value": value}], idp_style
            ),
        )

    # Step 1: deactivate.
    deactivated = _patch_active(False)
    assert deactivated.status_code == 200
    assert deactivated.json()["active"] is False

    # Step 2: reactivate.
    reactivated = _patch_active(True)
    assert reactivated.status_code == 200
    assert reactivated.json()["active"] is True
|
||||
|
||||
|
||||
def test_delete_user(scim_token: str, idp_style: str) -> None:
    """DELETE /Users/{id} answers 204 once, then 404 on a repeat."""
    created = _create_scim_user(
        scim_token,
        f"scim_delete_{idp_style}@example.com",
        f"ext-delete-{idp_style}",
        idp_style,
    ).json()
    user_path = f"/Users/{created['id']}"

    first_delete = ScimClient.delete(user_path, scim_token)
    assert first_delete.status_code == 204

    # Repeating the DELETE must yield 404 per RFC 7644 §3.6 (mapping removed).
    second_delete = ScimClient.delete(user_path, scim_token)
    assert second_delete.status_code == 404
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user_missing_external_id(scim_token: str) -> None:
    """POST /Users with no externalId in the payload is rejected with 400."""
    payload_without_external_id = {
        "schemas": [SCIM_USER_SCHEMA],
        "userName": "scim_no_extid@example.com",
        "active": True,
    }

    response = ScimClient.post("/Users", scim_token, json=payload_without_external_id)

    assert response.status_code == 400
    # The error detail should name the missing attribute.
    detail = response.json()["detail"]
    assert "externalId" in detail
|
||||
|
||||
|
||||
def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None:
    """A second POST /Users with an email that is already taken returns 409."""
    email = f"scim_dup_{idp_style}@example.com"

    first = _create_scim_user(
        scim_token, email, f"ext-dup-1-{idp_style}", idp_style
    )
    assert first.status_code == 201

    # Different externalId, same email: must conflict.
    second = _create_scim_user(
        scim_token, email, f"ext-dup-2-{idp_style}", idp_style
    )
    assert second.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_user(scim_token: str) -> None:
    """Fetching a user id that does not exist yields 404."""
    missing_user_path = "/Users/00000000-0000-0000-0000-000000000000"
    response = ScimClient.get(missing_user_path, scim_token)
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None:
    """GET /Users filtered on externalId returns exactly the matching user."""
    ext_id = f"ext-unique-filter-id-{idp_style}"
    email = f"scim_extfilter_{idp_style}@example.com"
    _create_scim_user(scim_token, email, ext_id, idp_style)

    filtered_path = f'/Users?filter=externalId eq "{ext_id}"'
    response = ScimClient.get(filtered_path, scim_token)
    assert response.status_code == 200

    result = response.json()
    assert result["totalResults"] == 1
    only_match = result["Resources"][0]
    assert only_match["externalId"] == ext_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seat-limit enforcement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
    """Store license metadata with the given seat cap in Redis.

    The entry is written under ``_LICENSE_REDIS_KEY`` with a 300-second expiry
    and is valid for 365 days from the current UTC time.
    """
    issued = datetime.now(timezone.utc)
    expires = issued + timedelta(days=365)

    license_metadata = LicenseMetadata(
        tenant_id="public",
        organization_name="Test Org",
        seats=seats,
        used_seats=0,  # check_seat_availability recalculates from DB
        plan_type=PlanType.ANNUAL,
        issued_at=issued,
        expires_at=expires,
        status=ApplicationStatus.ACTIVE,
        source=LicenseSource.MANUAL_UPLOAD,
    )
    serialized = license_metadata.model_dump_json()
    r.set(_LICENSE_REDIS_KEY, serialized, ex=300)
|
||||
|
||||
|
||||
def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None:
    """POST /Users is refused with 403 once the licensed seats are used up."""
    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)

    # admin_user already occupies 1 seat; cap at 1 -> full
    _seed_license(redis_client, seats=1)

    try:
        response = _create_scim_user(
            token=scim_token,
            email=f"scim_blocked_{idp_style}@example.com",
            external_id=f"ext-blocked-{idp_style}",
            idp_style=idp_style,
        )
        assert response.status_code == 403
        detail = response.json()["detail"]
        assert "seat" in detail.lower()
    finally:
        # Always drop the seeded license so later tests run without a seat cap.
        redis_client.delete(_LICENSE_REDIS_KEY)
|
||||
|
||||
|
||||
def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None:
    """PATCH with active=true is refused with 403 once the licensed seats are used up."""
    # Provision the user while no license is seeded, so creation is not blocked.
    created = _create_scim_user(
        scim_token,
        f"scim_reactivate_blocked_{idp_style}@example.com",
        f"ext-reactivate-blocked-{idp_style}",
        idp_style,
    ).json()
    assert created["active"] is True
    user_path = f"/Users/{created['id']}"

    def _active_patch(value: bool) -> dict:
        # PatchOp body that replaces the "active" attribute.
        return _make_patch_request(
            [{"op": "replace", "path": "active", "value": value}], idp_style
        )

    deactivated = ScimClient.patch(user_path, scim_token, json=_active_patch(False))
    assert deactivated.status_code == 200
    assert deactivated.json()["active"] is False

    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)

    # Seed license capped at current active users -> reactivation should fail
    _seed_license(redis_client, seats=1)

    try:
        response = ScimClient.patch(user_path, scim_token, json=_active_patch(True))
        assert response.status_code == 403
        detail = response.json()["detail"]
        assert "seat" in detail.lower()
    finally:
        # Always drop the seeded license so later tests run without a seat cap.
        redis_client.delete(_LICENSE_REDIS_KEY)
|
||||
@@ -1,121 +0,0 @@
|
||||
"""Integration tests for Slack user deactivation and reactivation via admin endpoints.
|
||||
|
||||
Verifies that:
|
||||
- Slack users can be deactivated by admins
|
||||
- Deactivated Slack users can be reactivated by admins
|
||||
- Reactivation is blocked when the seat limit is reached
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
    """Write a LicenseMetadata entry into Redis with the given seat cap.

    Args:
        r: Redis client the entry is written through.
        seats: Seat limit to record in the license.

    The entry is stored under ``_LICENSE_REDIS_KEY`` with a 300-second expiry
    and is valid for 365 days from now.
    """
    # Imported locally because this module's file-level imports bring in only
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    # Use an aware UTC timestamp: ``datetime.utcnow()`` returns a naive value
    # and is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    metadata = LicenseMetadata(
        tenant_id="public",
        organization_name="Test Org",
        seats=seats,
        used_seats=0,
        plan_type=PlanType.ANNUAL,
        issued_at=now,
        expires_at=now + timedelta(days=365),
        status=ApplicationStatus.ACTIVE,
        source=LicenseSource.MANUAL_UPLOAD,
    )
    r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
|
||||
|
||||
|
||||
def _clear_license(r: redis.Redis) -> None:
    """Delete the seeded license entry stored under ``_LICENSE_REDIS_KEY``."""
    license_key = _LICENSE_REDIS_KEY
    r.delete(license_key)
|
||||
|
||||
|
||||
def _redis() -> redis.Redis:
    """Build a Redis client from the app's configured host, port, and DB number."""
    client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
    return client
|
||||
|
||||
|
||||
def _get_user_is_active(email: str, admin_user: DATestUser) -> bool:
    """Return the ``is_active`` flag of the single user whose email equals ``email``.

    The user page is searched by ``email`` on behalf of ``admin_user``; the
    helper asserts that exactly one returned item has that exact email.
    """
    page = UserManager.get_user_page(
        user_performing_action=admin_user,
        search_query=email,
    )

    # The search may return near matches; keep only exact email hits.
    exact_hits = []
    for candidate in page.items:
        if candidate.email == email:
            exact_hits.append(candidate)

    assert len(exact_hits) == 1, f"Expected exactly 1 user with email {email}"
    (user,) = exact_hits
    return user.is_active
|
||||
|
||||
|
||||
def test_slack_user_deactivate_and_reactivate(reset: None) -> None:  # noqa: ARG001
    """An admin can switch a Slack user off and then back on."""
    admin_user = UserManager.create(name="admin_user")

    created_user = UserManager.create(name="slack_test_user")
    slack_user = UserManager.set_role(
        user_to_set=created_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # First deactivate, then reactivate, checking the reported flag each time.
    for target_status in (False, True):
        UserManager.set_status(
            slack_user,
            target_status=target_status,
            user_performing_action=admin_user,
        )
        assert _get_user_is_active(slack_user.email, admin_user) is target_status
|
||||
|
||||
|
||||
def test_slack_user_reactivation_blocked_by_seat_limit(
    reset: None,  # noqa: ARG001
) -> None:
    """Activating a deactivated Slack user answers 402 while the seats are full."""
    redis_client = _redis()

    admin_user = UserManager.create(name="admin_user")

    created_user = UserManager.create(name="slack_test_user")
    slack_user = UserManager.set_role(
        user_to_set=created_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Deactivate before the license is seeded.
    UserManager.set_status(
        slack_user, target_status=False, user_performing_action=admin_user
    )

    # License allows 1 seat — only admin counts
    _seed_license(redis_client, seats=1)

    try:
        activate_url = f"{API_SERVER_URL}/manage/admin/activate-user"
        response = requests.patch(
            url=activate_url,
            json={"user_email": slack_user.email},
            headers=admin_user.headers,
        )
        assert response.status_code == 402
    finally:
        # Remove the seeded license regardless of the outcome.
        _clear_license(redis_client)
|
||||
@@ -1,20 +1,11 @@
|
||||
"""Tests for license database CRUD operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.db.models import License
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class TestGetLicense:
|
||||
@@ -109,108 +100,3 @@ class TestDeleteLicense:
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def _make_license_metadata(seats: int = 10) -> LicenseMetadata:
    """Build active, manually uploaded annual license metadata with ``seats`` seats.

    The license is issued at the current UTC time and expires 365 days later.
    """
    issued = datetime.now(timezone.utc)
    expires = issued + timedelta(days=365)

    license_metadata = LicenseMetadata(
        tenant_id="public",
        seats=seats,
        used_seats=0,
        plan_type=PlanType.ANNUAL,
        issued_at=issued,
        expires_at=expires,
        status=ApplicationStatus.ACTIVE,
        source=LicenseSource.MANUAL_UPLOAD,
    )
    return license_metadata
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilitySelfHosted:
    """Seat checks for self-hosted (MULTI_TENANT=False)."""

    @patch("ee.onyx.db.license.get_license_metadata", return_value=None)
    def test_no_license_means_unlimited(self, _mock_meta: MagicMock) -> None:
        """With no license metadata, a seat request is reported as available."""
        db_session = MagicMock()
        outcome = check_seat_availability(db_session, seats_needed=1)
        assert outcome.available is True

    @patch("ee.onyx.db.license.get_used_seats", return_value=5)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_available(self, mock_meta: MagicMock, _mock_used: MagicMock) -> None:
        """5 of 10 seats used: one more seat is available."""
        mock_meta.return_value = _make_license_metadata(seats=10)
        db_session = MagicMock()
        outcome = check_seat_availability(db_session, seats_needed=1)
        assert outcome.available is True

    @patch("ee.onyx.db.license.get_used_seats", return_value=10)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_full_blocks_creation(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        """10 of 10 seats used: the request is denied with a usage message."""
        mock_meta.return_value = _make_license_metadata(seats=10)
        db_session = MagicMock()
        outcome = check_seat_availability(db_session, seats_needed=1)
        assert outcome.available is False
        message = outcome.error_message
        assert message is not None
        assert "10 of 10" in message

    @patch("ee.onyx.db.license.get_used_seats", return_value=10)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_exactly_at_capacity_allows_no_more(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        """With every seat already used, one more seat is not available."""
        mock_meta.return_value = _make_license_metadata(seats=10)
        db_session = MagicMock()
        outcome = check_seat_availability(db_session, seats_needed=1)
        assert outcome.available is False

    @patch("ee.onyx.db.license.get_used_seats", return_value=9)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_filling_to_capacity_is_allowed(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        """9 of 10 seats used: taking the last seat is allowed."""
        mock_meta.return_value = _make_license_metadata(seats=10)
        db_session = MagicMock()
        outcome = check_seat_availability(db_session, seats_needed=1)
        assert outcome.available is True
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilityMultiTenant:
|
||||
"""Seat checks for multi-tenant cloud (MULTI_TENANT=True).
|
||||
|
||||
Verifies that get_used_seats takes the MULTI_TENANT branch
|
||||
and delegates to get_tenant_count.
|
||||
"""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=5,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_available_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is True
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=10,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_full_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for the _impl functions' redis_locking parameter.
|
||||
|
||||
Verifies that:
|
||||
- redis_locking=True acquires/releases Redis locks and clears queued keys
|
||||
- redis_locking=False skips all Redis operations entirely
|
||||
- Both paths execute the same business logic (DB lookup, status check)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _mock_session_returning_none() -> MagicMock:
|
||||
"""Return a mock session whose .get() returns None (file not found)."""
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
return session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once_with(tenant_id="test-tenant")
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_both_paths_call_db_get(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
"""Both redis_locking=True and False should call db_session.get(UserFile, ...)."""
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
uid = str(uuid4())
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=True)
|
||||
call_count_true = session.get.call_count
|
||||
|
||||
session.reset_mock()
|
||||
mock_get_session.reset_mock()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=False)
|
||||
call_count_false = session.get.call_count
|
||||
|
||||
assert call_count_true == call_count_false == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
@@ -1,421 +0,0 @@
|
||||
"""Tests for no-vector-DB user file processing paths.
|
||||
|
||||
Verifies that when DISABLE_VECTOR_DB is True:
|
||||
- process_user_file_impl calls _process_user_file_without_vector_db (not indexing)
|
||||
- _process_user_file_without_vector_db extracts text, counts tokens, stores plaintext,
|
||||
sets status=COMPLETED and chunk_count=0
|
||||
- delete_user_file_impl skips vector DB chunk deletion
|
||||
- project_sync_user_file_impl skips vector DB metadata update
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_without_vector_db,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
LLM_FACTORY_MODULE = "onyx.llm.factory"
|
||||
|
||||
|
||||
def _make_documents(texts: list[str]) -> list[Document]:
|
||||
"""Build a list of Document objects with the given section texts."""
|
||||
return [
|
||||
Document(
|
||||
id=str(uuid4()),
|
||||
source=DocumentSource.USER_FILE,
|
||||
sections=[TextSection(text=t)],
|
||||
semantic_identifier=f"test-doc-{i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, t in enumerate(texts)
|
||||
]
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
file_id: str = "test-file-id",
|
||||
name: str = "test.txt",
|
||||
) -> MagicMock:
|
||||
"""Return a MagicMock mimicking a UserFile ORM instance."""
|
||||
uf = MagicMock()
|
||||
uf.id = uuid4()
|
||||
uf.file_id = file_id
|
||||
uf.name = name
|
||||
uf.status = status
|
||||
uf.token_count = None
|
||||
uf.chunk_count = None
|
||||
uf.last_project_sync_at = None
|
||||
uf.projects = []
|
||||
uf.assistants = []
|
||||
uf.needs_project_sync = True
|
||||
uf.needs_persona_sync = True
|
||||
return uf
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_without_vector_db — direct tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileWithoutVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_extracts_and_combines_text(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=[1, 2, 3, 4, 5])
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["hello world", "foo bar"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
stored_text = mock_store_plaintext.call_args.kwargs["plaintext_content"]
|
||||
assert "hello world" in stored_text
|
||||
assert "foo bar" in stored_text
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_computes_token_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=list(range(42)))
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["some text content"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count == 42
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_token_count_falls_back_to_none_on_error(
|
||||
self,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_encode: MagicMock, # noqa: ARG002
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_llm.side_effect = RuntimeError("No LLM configured")
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count is None
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_stores_plaintext(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["content to store"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
mock_store_plaintext.assert_called_once_with(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content="content to store",
|
||||
)
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_sets_completed_status_and_zero_chunk_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.COMPLETED
|
||||
assert uf.chunk_count == 0
|
||||
assert uf.last_project_sync_at is not None
|
||||
db_session.add.assert_called_once_with(uf)
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_preserves_deleting_status(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.DELETING
|
||||
assert uf.chunk_count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl — branching on DISABLE_VECTOR_DB
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessImplBranching:
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_without_vector_db_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_without_vdb.assert_called_once()
|
||||
mock_with_indexing.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_with_indexing_when_vector_db_enabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_with_indexing.assert_called_once()
|
||||
mock_without_vdb.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.run_indexing_pipeline")
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_indexing_pipeline_not_called_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
mock_run_pipeline: MagicMock,
|
||||
) -> None:
|
||||
"""End-to-end: verify run_indexing_pipeline is never invoked."""
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["content"])]
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock),
|
||||
patch(f"{LLM_FACTORY_MODULE}.get_default_llm"),
|
||||
patch(
|
||||
f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func",
|
||||
return_value=MagicMock(return_value=[1, 2, 3]),
|
||||
),
|
||||
):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_run_pipeline.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_deletion(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
mock_get_file_store.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_deletes_file_store_and_db_record(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert file_store.delete_file.call_count == 2
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_update(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_clears_sync_flags(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert uf.needs_project_sync is False
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.last_project_sync_at is not None
|
||||
session.add.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
@@ -19,7 +19,6 @@ from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
@@ -27,10 +26,6 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
# Every supported SCIM provider must appear here so that all endpoint tests
|
||||
# run against it. When adding a new provider, add its class to this list.
|
||||
SCIM_PROVIDERS: list[type[ScimProvider]] = [OktaProvider, EntraProvider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
@@ -46,10 +41,10 @@ def mock_token() -> MagicMock:
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture(params=SCIM_PROVIDERS, ids=[p.__name__ for p in SCIM_PROVIDERS])
|
||||
def provider(request: pytest.FixtureRequest) -> ScimProvider:
|
||||
"""Parameterized provider — runs each test with every provider in SCIM_PROVIDERS."""
|
||||
return request.param()
|
||||
@pytest.fixture
|
||||
def provider() -> ScimProvider:
|
||||
"""An OktaProvider instance for endpoint tests."""
|
||||
return OktaProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Tests for startup validation in no-vector-DB mode.
|
||||
|
||||
Verifies that DISABLE_VECTOR_DB raises RuntimeError when combined with
|
||||
incompatible settings (MULTI_TENANT, ENABLE_CRAFT).
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestValidateNoVectorDbSettings:
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", False)
|
||||
def test_no_error_when_vector_db_enabled(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", False)
|
||||
def test_no_error_when_no_conflicts(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
def test_raises_on_multi_tenant(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_raises_on_enable_craft(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="ENABLE_CRAFT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_multi_tenant_checked_before_craft(self) -> None:
|
||||
"""MULTI_TENANT is checked first, so it should be the error raised."""
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
@@ -1,196 +0,0 @@
|
||||
"""Tests for tool construction when DISABLE_VECTOR_DB is True.
|
||||
|
||||
Verifies that:
|
||||
- SearchTool.is_available() returns False when vector DB is disabled
|
||||
- OpenURLTool.is_available() returns False when vector DB is disabled
|
||||
- The force-add SearchTool block is suppressed when DISABLE_VECTOR_DB
|
||||
- FileReaderTool.is_available() returns True when vector DB is disabled
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
|
||||
APP_CONFIGS_MODULE = "onyx.configs.app_configs"
|
||||
FILE_READER_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SearchTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch("onyx.db.connector.check_user_files_exist", return_value=True)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled_and_files_exist(
|
||||
self,
|
||||
mock_connectors: MagicMock, # noqa: ARG002
|
||||
mock_federated: MagicMock, # noqa: ARG002
|
||||
mock_user_files: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OpenURLTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenURLToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FileReaderTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileReaderToolAvailability:
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Force-add SearchTool suppression
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForceAddSearchToolGuard:
|
||||
def test_force_add_block_checks_disable_vector_db(self) -> None:
|
||||
"""The force-add SearchTool block in construct_tools should include
|
||||
`not DISABLE_VECTOR_DB` so that forced search is also suppressed
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
assert "DISABLE_VECTOR_DB" in source, (
|
||||
"construct_tools should reference DISABLE_VECTOR_DB "
|
||||
"to suppress force-adding SearchTool"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Persona API — _validate_vector_db_knowledge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateVectorDbKnowledge:
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_set_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1]
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "document sets" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_hierarchy_node_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = [1]
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "hierarchy nodes" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = ["doc-abc"]
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "documents" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_allows_user_files_only(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
False,
|
||||
)
|
||||
def test_allows_everything_when_vector_db_enabled(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1, 2]
|
||||
request.hierarchy_node_ids = [3]
|
||||
request.document_ids = ["doc-x"]
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
@@ -1,237 +0,0 @@
|
||||
"""Tests for the FileReaderTool.
|
||||
|
||||
Verifies:
|
||||
- Tool definition schema is well-formed
|
||||
- File ID validation (allowlist, UUID format)
|
||||
- Character range extraction and clamping
|
||||
- Error handling for missing parameters and non-text files
|
||||
- is_available() reflects DISABLE_VECTOR_DB
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FILE_ID_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import MAX_NUM_CHARS
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import NUM_CHARS_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
|
||||
START_CHAR_FIELD,
|
||||
)
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
_PLACEMENT = Placement(turn_index=0)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
user_file_ids: list | None = None,
|
||||
chat_file_ids: list | None = None,
|
||||
) -> FileReaderTool:
|
||||
emitter = MagicMock()
|
||||
return FileReaderTool(
|
||||
tool_id=99,
|
||||
emitter=emitter,
|
||||
user_file_ids=user_file_ids or [],
|
||||
chat_file_ids=chat_file_ids or [],
|
||||
)
|
||||
|
||||
|
||||
def _text_file(content: str, filename: str = "test.txt") -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id="some-file-id",
|
||||
content=content.encode("utf-8"),
|
||||
file_type=ChatFileType.PLAIN_TEXT,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolMetadata:
|
||||
def test_tool_name(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_tool_definition_schema(self) -> None:
|
||||
tool = _make_tool()
|
||||
defn = tool.tool_definition()
|
||||
assert defn["type"] == "function"
|
||||
func = defn["function"]
|
||||
assert func["name"] == "read_file"
|
||||
props = func["parameters"]["properties"]
|
||||
assert FILE_ID_FIELD in props
|
||||
assert START_CHAR_FIELD in props
|
||||
assert NUM_CHARS_FIELD in props
|
||||
assert func["parameters"]["required"] == [FILE_ID_FIELD]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File ID validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileIdValidation:
|
||||
def test_rejects_invalid_uuid(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Invalid file_id"):
|
||||
tool._validate_file_id("not-a-uuid")
|
||||
|
||||
def test_rejects_file_not_in_allowlist(self) -> None:
|
||||
tool = _make_tool(user_file_ids=[uuid4()])
|
||||
other_id = uuid4()
|
||||
with pytest.raises(ToolCallException, match="not in available files"):
|
||||
tool._validate_file_id(str(other_id))
|
||||
|
||||
def test_accepts_user_file_id(self) -> None:
|
||||
uid = uuid4()
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
assert tool._validate_file_id(str(uid)) == uid
|
||||
|
||||
def test_accepts_chat_file_id(self) -> None:
|
||||
cid = uuid4()
|
||||
tool = _make_tool(chat_file_ids=[cid])
|
||||
assert tool._validate_file_id(str(cid)) == cid
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# run() — character range extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRun:
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_returns_full_content_by_default(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "Hello, world!"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
assert content in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_respects_start_char_and_num_chars(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "abcdefghijklmnop"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), START_CHAR_FIELD: 4, NUM_CHARS_FIELD: 6},
|
||||
)
|
||||
assert "efghij" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_clamps_num_chars_to_max(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * (MAX_NUM_CHARS + 500)
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: MAX_NUM_CHARS + 9999},
|
||||
)
|
||||
assert f"Characters 0-{MAX_NUM_CHARS}" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_includes_continuation_hint(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * 100
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: 10},
|
||||
)
|
||||
assert "use start_char=10 to continue reading" in resp.llm_facing_response
|
||||
|
||||
def test_raises_on_missing_file_id(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Missing required"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
)
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_raises_on_non_text_file(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
mock_load_user_file.return_value = InMemoryChatFile(
|
||||
file_id="img",
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
with pytest.raises(ToolCallException, match="not a text file"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsAvailable:
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
@@ -16,15 +16,12 @@
|
||||
# This overlay:
|
||||
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
|
||||
# so they do not start by default
|
||||
# - Moves the background worker to the "background" profile (the API server
|
||||
# handles all background work via FastAPI BackgroundTasks)
|
||||
# - Makes the depends_on references to removed services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on the api_server
|
||||
# - Makes the depends_on references to those services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on backend services
|
||||
#
|
||||
# To selectively bring services back:
|
||||
# --profile vectordb Vespa + indexing model server
|
||||
# --profile inference Inference model server
|
||||
# --profile background Background worker (Celery)
|
||||
# --profile code-interpreter Code interpreter
|
||||
# =============================================================================
|
||||
|
||||
@@ -46,20 +43,20 @@ services:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move the background worker to a profile so it does not start by default.
|
||||
# The API server handles all background work in NO_VECTOR_DB mode.
|
||||
background:
|
||||
profiles: ["background"]
|
||||
depends_on:
|
||||
index:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
indexing_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
environment:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move Vespa and indexing model server to a profile so they do not start.
|
||||
index:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_beat.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_beat.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_heavy.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_light.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_primary.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_primary.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_user_file_processing.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_user_file_processing.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -28,9 +28,7 @@ postgresql:
|
||||
# -- Master toggle for vector database support. When false:
|
||||
# - Sets DISABLE_VECTOR_DB=true on all backend pods
|
||||
# - Skips the indexing model server deployment (embeddings not needed)
|
||||
# - Skips ALL celery worker deployments (beat, primary, light, heavy,
|
||||
# monitoring, user-file-processing, docprocessing, docfetching) — the
|
||||
# API server handles background work via FastAPI BackgroundTasks
|
||||
# - Skips docprocessing and docfetching celery workers
|
||||
# - You should also set vespa.enabled=false and opensearch.enabled=false
|
||||
# to prevent those subcharts from deploying
|
||||
vectorDB:
|
||||
|
||||
@@ -40,8 +40,6 @@ const TRAY_MENU_OPEN_APP_ID: &str = "tray_open_app";
|
||||
const TRAY_MENU_OPEN_CHAT_ID: &str = "tray_open_chat";
|
||||
const TRAY_MENU_SHOW_IN_BAR_ID: &str = "tray_show_in_menu_bar";
|
||||
const TRAY_MENU_QUIT_ID: &str = "tray_quit";
|
||||
const MENU_SHOW_MENU_BAR_ID: &str = "show_menu_bar";
|
||||
const MENU_HIDE_DECORATIONS_ID: &str = "hide_window_decorations";
|
||||
const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
|
||||
(() => {
|
||||
if (window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__) {
|
||||
@@ -173,92 +171,25 @@ const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
|
||||
})();
|
||||
"##;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
const MENU_KEY_HANDLER_SCRIPT: &str = r#"
|
||||
(() => {
|
||||
if (window.__ONYX_MENU_KEY_HANDLER__) return;
|
||||
window.__ONYX_MENU_KEY_HANDLER__ = true;
|
||||
|
||||
let altHeld = false;
|
||||
|
||||
function invoke(cmd) {
|
||||
const fn_ =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof fn_ === 'function') fn_(cmd);
|
||||
}
|
||||
|
||||
function releaseAltAndHideMenu() {
|
||||
if (!altHeld) {
|
||||
return;
|
||||
}
|
||||
altHeld = false;
|
||||
invoke('hide_menu_bar_temporary');
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Alt') {
|
||||
if (!altHeld) {
|
||||
altHeld = true;
|
||||
invoke('show_menu_bar_temporarily');
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (e.altKey && e.key === 'F1') {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
altHeld = false;
|
||||
invoke('toggle_menu_bar');
|
||||
return;
|
||||
}
|
||||
}, true);
|
||||
|
||||
document.addEventListener('keyup', (e) => {
|
||||
if (e.key === 'Alt' && altHeld) {
|
||||
releaseAltAndHideMenu();
|
||||
}
|
||||
}, true);
|
||||
|
||||
window.addEventListener('blur', () => {
|
||||
releaseAltAndHideMenu();
|
||||
});
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.hidden) {
|
||||
releaseAltAndHideMenu();
|
||||
}
|
||||
});
|
||||
})();
|
||||
"#;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
/// The Onyx server URL (default: https://cloud.onyx.app)
|
||||
pub server_url: String,
|
||||
|
||||
/// Optional: Custom window title
|
||||
#[serde(default = "default_window_title")]
|
||||
pub window_title: String,
|
||||
|
||||
#[serde(default = "default_show_menu_bar")]
|
||||
pub show_menu_bar: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub hide_window_decorations: bool,
|
||||
}
|
||||
|
||||
fn default_window_title() -> String {
|
||||
"Onyx".to_string()
|
||||
}
|
||||
|
||||
fn default_show_menu_bar() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server_url: DEFAULT_SERVER_URL.to_string(),
|
||||
window_title: default_window_title(),
|
||||
show_menu_bar: true,
|
||||
hide_window_decorations: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -316,7 +247,6 @@ struct ConfigState {
|
||||
config: RwLock<AppConfig>,
|
||||
config_initialized: RwLock<bool>,
|
||||
app_base_url: RwLock<Option<Url>>,
|
||||
menu_temporarily_visible: RwLock<bool>,
|
||||
}
|
||||
|
||||
fn focus_main_window(app: &AppHandle) {
|
||||
@@ -371,7 +301,6 @@ fn trigger_new_window(app: &AppHandle) {
|
||||
inject_titlebar(window.clone());
|
||||
}
|
||||
|
||||
apply_settings_to_window(&handle, &window);
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
});
|
||||
@@ -648,15 +577,18 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
|
||||
#[cfg(target_os = "linux")]
|
||||
let builder = builder.background_color(tauri::window::Color(0x1a, 0x1a, 0x2e, 0xff));
|
||||
|
||||
let window = builder.build().map_err(|e| e.to_string())?;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let window = builder.build().map_err(|e| e.to_string())?;
|
||||
// Apply vibrancy effect and inject titlebar
|
||||
let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None);
|
||||
inject_titlebar(window.clone());
|
||||
}
|
||||
|
||||
apply_settings_to_window(&app, &window);
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _window = builder.build().map_err(|e| e.to_string())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -692,142 +624,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Window Settings
|
||||
// ============================================================================
|
||||
|
||||
fn find_check_menu_item(
|
||||
app: &AppHandle,
|
||||
id: &str,
|
||||
) -> Option<CheckMenuItem<tauri::Wry>> {
|
||||
let menu = app.menu()?;
|
||||
for item in menu.items().ok()? {
|
||||
if let Some(submenu) = item.as_submenu() {
|
||||
for sub_item in submenu.items().ok()? {
|
||||
if let Some(check) = sub_item.as_check_menuitem() {
|
||||
if check.id().as_ref() == id {
|
||||
return Some(check.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn apply_settings_to_window(app: &AppHandle, window: &tauri::WebviewWindow) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let config = state.config.read().unwrap();
|
||||
let temp_visible = *state.menu_temporarily_visible.read().unwrap();
|
||||
if !config.show_menu_bar && !temp_visible {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
if config.hide_window_decorations {
|
||||
let _ = window.set_decorations(false);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_menu_bar_toggle(app: &AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let show = {
|
||||
let mut config = state.config.write().unwrap();
|
||||
config.show_menu_bar = !config.show_menu_bar;
|
||||
let _ = save_config(&config);
|
||||
config.show_menu_bar
|
||||
};
|
||||
|
||||
*state.menu_temporarily_visible.write().unwrap() = false;
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
if show {
|
||||
let _ = window.show_menu();
|
||||
} else {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_decorations_toggle(app: &AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let hide = {
|
||||
let mut config = state.config.write().unwrap();
|
||||
config.hide_window_decorations = !config.hide_window_decorations;
|
||||
let _ = save_config(&config);
|
||||
config.hide_window_decorations
|
||||
};
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.set_decorations(!hide);
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn toggle_menu_bar(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
handle_menu_bar_toggle(&app);
|
||||
|
||||
let state = app.state::<ConfigState>();
|
||||
let checked = state.config.read().unwrap().show_menu_bar;
|
||||
if let Some(check) = find_check_menu_item(&app, MENU_SHOW_MENU_BAR_ID) {
|
||||
let _ = check.set_checked(checked);
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn show_menu_bar_temporarily(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.config.read().unwrap().show_menu_bar {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut temp = state.menu_temporarily_visible.write().unwrap();
|
||||
if *temp {
|
||||
return;
|
||||
}
|
||||
*temp = true;
|
||||
drop(temp);
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.show_menu();
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn hide_menu_bar_temporary(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let mut temp = state.menu_temporarily_visible.write().unwrap();
|
||||
if !*temp {
|
||||
return;
|
||||
}
|
||||
*temp = false;
|
||||
drop(temp);
|
||||
|
||||
if state.config.read().unwrap().show_menu_bar {
|
||||
return;
|
||||
}
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -871,59 +667,6 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
menu.prepend(&file_menu)?;
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let config = app.state::<ConfigState>();
|
||||
let config_guard = config.config.read().unwrap();
|
||||
|
||||
let show_menu_bar_item = CheckMenuItem::with_id(
|
||||
app,
|
||||
MENU_SHOW_MENU_BAR_ID,
|
||||
"Show Menu Bar",
|
||||
true,
|
||||
config_guard.show_menu_bar,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
let hide_decorations_item = CheckMenuItem::with_id(
|
||||
app,
|
||||
MENU_HIDE_DECORATIONS_ID,
|
||||
"Hide Window Decorations",
|
||||
true,
|
||||
config_guard.hide_window_decorations,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
drop(config_guard);
|
||||
|
||||
if let Some(window_menu) = menu
|
||||
.items()?
|
||||
.into_iter()
|
||||
.filter_map(|item| item.as_submenu().cloned())
|
||||
.find(|submenu| submenu.text().ok().as_deref() == Some("Window"))
|
||||
{
|
||||
window_menu.append(&show_menu_bar_item)?;
|
||||
window_menu.append(&hide_decorations_item)?;
|
||||
} else {
|
||||
let window_menu = SubmenuBuilder::new(app, "Window")
|
||||
.item(&show_menu_bar_item)
|
||||
.item(&hide_decorations_item)
|
||||
.build()?;
|
||||
|
||||
let items = menu.items()?;
|
||||
let help_idx = items
|
||||
.iter()
|
||||
.position(|item| {
|
||||
item.as_submenu()
|
||||
.and_then(|s| s.text().ok())
|
||||
.as_deref()
|
||||
== Some("Help")
|
||||
})
|
||||
.unwrap_or(items.len());
|
||||
menu.insert(&window_menu, help_idx)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(help_menu) = menu
|
||||
.get(HELP_SUBMENU_ID)
|
||||
.and_then(|item| item.as_submenu().cloned())
|
||||
@@ -1058,7 +801,6 @@ fn main() {
|
||||
config: RwLock::new(config),
|
||||
config_initialized: RwLock::new(config_initialized),
|
||||
app_base_url: RwLock::new(None),
|
||||
menu_temporarily_visible: RwLock::new(false),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_server_url,
|
||||
@@ -1074,18 +816,13 @@ fn main() {
|
||||
go_forward,
|
||||
new_window,
|
||||
reset_config,
|
||||
start_drag_window,
|
||||
toggle_menu_bar,
|
||||
show_menu_bar_temporarily,
|
||||
hide_menu_bar_temporary
|
||||
start_drag_window
|
||||
])
|
||||
.on_menu_event(|app, event| match event.id().as_ref() {
|
||||
"open_docs" => open_docs(),
|
||||
"new_chat" => trigger_new_chat(app),
|
||||
"new_window" => trigger_new_window(app),
|
||||
"open_settings" => open_settings(app),
|
||||
"show_menu_bar" => handle_menu_bar_toggle(app),
|
||||
"hide_window_decorations" => handle_decorations_toggle(app),
|
||||
_ => {}
|
||||
})
|
||||
.setup(move |app| {
|
||||
@@ -1118,8 +855,6 @@ fn main() {
|
||||
#[cfg(target_os = "macos")]
|
||||
inject_titlebar(window.clone());
|
||||
|
||||
apply_settings_to_window(&app_handle, &window);
|
||||
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
|
||||
@@ -1128,27 +863,7 @@ fn main() {
|
||||
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
|
||||
inject_chat_link_intercept(webview);
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
|
||||
|
||||
let app = webview.app_handle();
|
||||
let state = app.state::<ConfigState>();
|
||||
let config = state.config.read().unwrap();
|
||||
let temp_visible = *state.menu_temporarily_visible.read().unwrap();
|
||||
let label = webview.label().to_string();
|
||||
if !config.show_menu_bar && !temp_visible {
|
||||
if let Some(win) = app.get_webview_window(&label) {
|
||||
let _ = win.hide_menu();
|
||||
}
|
||||
}
|
||||
if config.hide_window_decorations {
|
||||
if let Some(win) = app.get_webview_window(&label) {
|
||||
let _ = win.set_decorations(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-inject titlebar after every navigation/page load (macOS only)
|
||||
#[cfg(target_os = "macos")]
|
||||
let _ = webview.eval(TITLEBAR_SCRIPT);
|
||||
})
|
||||
|
||||
@@ -191,40 +191,6 @@ model_server = [
|
||||
"sentry-sdk[fastapi,celery,starlette]==2.14.0",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["backend", "tools/ods"]
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def generate_schema(output_path: str, tagged_for_docs: str | None = None) -> boo
|
||||
try:
|
||||
# Import here to avoid requiring backend dependencies when not generating schema
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from onyx.main import app as app_fn
|
||||
from onyx.main import app as app_fn # type: ignore
|
||||
except ImportError as e:
|
||||
print(f"Error: Failed to import required modules: {e}", file=sys.stderr)
|
||||
print(
|
||||
|
||||
6
uv.lock
generated
6
uv.lock
generated
@@ -3426,14 +3426,14 @@ html-clean = [
|
||||
|
||||
[[package]]
|
||||
name = "lxml-html-clean"
|
||||
version = "0.4.4"
|
||||
version = "0.4.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "lxml" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9a/a4/5c62acfacd69ff4f5db395100f5cfb9b54e7ac8c69a235e4e939fd13f021/lxml_html_clean-0.4.4.tar.gz", hash = "sha256:58f39a9d632711202ed1d6d0b9b47a904e306c85de5761543b90e3e3f736acfb", size = 23899, upload-time = "2026-02-27T09:35:52.911Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d9/cb/c9c5bb2a9c47292e236a808dd233a03531f53b626f36259dcd32b49c76da/lxml_html_clean-0.4.3.tar.gz", hash = "sha256:c9df91925b00f836c807beab127aac82575110eacff54d0a75187914f1bd9d8c", size = 21498, upload-time = "2025-10-02T20:49:24.895Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/76/7ffc1d3005cf7749123bc47cb3ea343cd97b0ac2211bab40f57283577d0e/lxml_html_clean-0.4.4-py3-none-any.whl", hash = "sha256:ce2ef506614ecb85ee1c5fe0a2aa45b06a19514ec7949e9c8f34f06925cfabcb", size = 14565, upload-time = "2026-02-27T09:35:51.86Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/4a/63a9540e3ca73709f4200564a737d63a4c8c9c4dd032bab8535f507c190a/lxml_html_clean-0.4.3-py3-none-any.whl", hash = "sha256:63fd7b0b9c3a2e4176611c2ca5d61c4c07ffca2de76c14059a81a2825833731e", size = 14177, upload-time = "2025-10-02T20:49:23.749Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
/**
 * Stroke-only icon drawn on a 16x16 viewBox.
 *
 * `size` is applied to both `width` and `height`. All remaining props are
 * spread onto the `<svg>` element after the built-in attributes, so a caller
 * can override them (for example `stroke`).
 */
const SvgCreditCard = ({ size, ...rest }: IconProps) => {
  return (
    <svg
      width={size}
      height={size}
      viewBox="0 0 16 16"
      fill="none"
      xmlns="http://www.w3.org/2000/svg"
      stroke="currentColor"
      {...rest}
    >
      <path
        d="M14.6667 6V4.00008C14.6667 3.26675 14.0667 2.66675 13.3333 2.66675H2.66668C1.93334 2.66675 1.33334 3.26675 1.33334 4.00008V6M14.6667 6V12.0001C14.6667 12.7334 14.0667 13.3334 13.3333 13.3334H2.66668C1.93334 13.3334 1.33334 12.7334 1.33334 12.0001V6M14.6667 6H1.33334"
        strokeWidth={1.5}
        strokeLinecap="round"
        strokeLinejoin="round"
      />
    </svg>
  );
};
|
||||
export default SvgCreditCard;
|
||||
@@ -51,7 +51,6 @@ export { default as SvgCode } from "@opal/icons/code";
|
||||
export { default as SvgCopy } from "@opal/icons/copy";
|
||||
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
|
||||
export { default as SvgCpu } from "@opal/icons/cpu";
|
||||
export { default as SvgCreditCard } from "@opal/icons/credit-card";
|
||||
export { default as SvgDashboard } from "@opal/icons/dashboard";
|
||||
export { default as SvgDevKit } from "@opal/icons/dev-kit";
|
||||
export { default as SvgDownload } from "@opal/icons/download";
|
||||
@@ -107,7 +106,6 @@ export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
|
||||
export { default as SvgMoon } from "@opal/icons/moon";
|
||||
export { default as SvgMoreHorizontal } from "@opal/icons/more-horizontal";
|
||||
export { default as SvgMusicSmall } from "@opal/icons/music-small";
|
||||
export { default as SvgNetworkGraph } from "@opal/icons/network-graph";
|
||||
export { default as SvgNotificationBubble } from "@opal/icons/notification-bubble";
|
||||
export { default as SvgOllama } from "@opal/icons/ollama";
|
||||
export { default as SvgOnyxLogo } from "@opal/icons/onyx-logo";
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
/**
 * Stroke-only icon drawn on a 16x16 viewBox, clipped to that 16x16 area.
 *
 * `size` is applied to both `width` and `height`. All remaining props are
 * spread onto the `<svg>` element after the built-in attributes, so a caller
 * can override them (for example `stroke`).
 */
const SvgNetworkGraph = ({ size, ...rest }: IconProps) => {
  return (
    <svg
      width={size}
      height={size}
      viewBox="0 0 16 16"
      fill="none"
      xmlns="http://www.w3.org/2000/svg"
      stroke="currentColor"
      {...rest}
    >
      {/* The clip path is a full-size rect, referenced by its fixed id. */}
      <g clipPath="url(#clip0_2828_22555)">
        <path
          d="M9.23744 4.48744C9.92086 3.80402 9.92086 2.69598 9.23744 2.01256C8.55402 1.32915 7.44598 1.32915 6.76256 2.01256C6.07915 2.69598 6.07915 3.80402 6.76256 4.48744M9.23744 4.48744C8.89573 4.82915 8.44787 5 8 5M9.23744 4.48744L11.7626 8.01256M6.76256 4.48744C7.10427 4.82915 7.55214 5 8 5M6.76256 4.48744L4.23744 8.01256M8 11C7.0335 11 6.25001 11.7835 6.25001 12.75C6.25001 13.7165 7.03351 14.5 8.00001 14.5C8.9665 14.5 9.75 13.7165 9.75 12.75C9.75 11.7835 8.9665 11 8 11ZM8 11V5M4.23744 8.01256C4.92085 8.69598 4.92422 9.81658 4.2408 10.5C3.55739 11.1834 2.44598 11.1709 1.76256 10.4874C1.07915 9.80402 1.07915 8.69598 1.76256 8.01256C2.44598 7.32915 3.55402 7.32915 4.23744 8.01256ZM11.7626 8.01256C11.0791 8.69598 11.0791 9.80402 11.7626 10.4874C12.446 11.1709 13.554 11.1709 14.2374 10.4874C14.9209 9.80402 14.9209 8.69598 14.2374 8.01256C13.554 7.32915 12.446 7.32915 11.7626 8.01256Z"
          strokeWidth={1.5}
          strokeLinecap="round"
          strokeLinejoin="round"
        />
      </g>
      <defs>
        <clipPath id="clip0_2828_22555">
          <rect width="16" height="16" fill="white" />
        </clipPath>
      </defs>
    </svg>
  );
};
|
||||
export default SvgNetworkGraph;
|
||||
@@ -5,7 +5,7 @@ import type { SizeVariant } from "@opal/shared";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -89,6 +89,7 @@ function ContentLg({
|
||||
}: ContentLgProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_LG_PRESETS[sizePreset];
|
||||
|
||||
@@ -130,6 +131,7 @@ function ContentLg({
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-lg-input",
|
||||
config.titleFont,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user