Compare commits

..

53 Commits

Author SHA1 Message Date
Dane Urban
15c828d0e2 Add default code interpreter base url 2026-03-02 23:51:55 -08:00
Evan Lohn
6aa56821d6 feat: use new cache backend where appropriate (#8889) 2026-03-03 07:14:39 +00:00
Danelegend
eda436de01 fix: Block deleting default provider (#8962) 2026-03-03 06:36:29 +00:00
Danelegend
07915a6c01 fix: Update frontend route calls to use new endpoints (#8968) 2026-03-03 06:19:47 +00:00
Nikolas Garza
2c3e9aecd1 fix(scim): only list SCIM-managed users and link pre-existing users (#8959) 2026-03-03 05:36:43 +00:00
Evan Lohn
fa29cc3849 feat: postgres cache backend (#8879) 2026-03-03 04:33:47 +00:00
Raunak Bhagat
24ac8b37d3 refactor(fe): define settings layout width presets as CSS variables (#8936) 2026-03-03 03:11:18 +00:00
Jessica Singh
be8b108ae4 chore(auth): ecs fargate deployment cleanup (#8589) 2026-03-03 02:34:04 +00:00
Danelegend
f380a75df3 fix: Non-intuitive llm auth exceptions (#8960) 2026-03-03 01:58:45 +00:00
Wenxi
21ec93663b chore: proxy cloud ph (#8961) 2026-03-03 01:43:15 +00:00
Raunak Bhagat
d789c74024 chore(icons): add SvgBookmark to @opal/icons (#8933) 2026-03-02 17:07:24 -08:00
Danelegend
fe014776f7 feat: embed code interpreter images in chat (#8875) 2026-03-03 00:57:56 +00:00
Danelegend
700ca0e0fc fix: Sticky background in CSV Preview Variant (#8939) 2026-03-03 00:08:17 +00:00
Jamison Lahman
a84f8238ec chore(fe): space between Manage All connectors button (#8938) 2026-03-02 23:56:08 +00:00
dependabot[bot]
4fc802e19d chore(deps): bump pypdf from 6.7.4 to 6.7.5 (#8932)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-02 23:55:34 +00:00
Danelegend
6cfd49439a chore: Bump code interpreter to 0.3.1 (#8937) 2026-03-02 23:49:58 +00:00
Jamison Lahman
71a1faa47e fix(fe): break long words in human messages (#8929)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-02 15:47:35 -08:00
Nikolas Garza
1a65217baf fix(scim): pass Okta Runscope spec test for OIN submission (#8925) 2026-03-02 23:03:38 +00:00
dependabot[bot]
30fa43b5fc chore(deps): bump pypdf from 6.7.3 to 6.7.4 (#8905)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-02 22:48:54 +00:00
Justin Tahara
28332fa24b fix(ui): InputComboBox search for users/groups (#8928) 2026-03-02 22:44:09 +00:00
Raunak Bhagat
1f5050f9f6 refactor(admin): update admin-page HealthCheckBanner (#8922) 2026-03-02 22:21:47 +00:00
acaprau
3c1d29d3cf chore(opensearch): Configure index settings for multitenant cloud (#8921) 2026-03-02 22:16:05 +00:00
Raunak Bhagat
709e3f4ca7 chore(icons): add SvgCreditCard and SvgNetworkGraph to @opal/icons (#8927) 2026-03-02 22:04:36 +00:00
Jamison Lahman
dfa27c08ef chore(deployment): optimize layer caching (#8924) 2026-03-02 20:58:46 +00:00
Nikolas Garza
13d60dcb0e test(scim): add integration tests for SCIM group CRUD (#8830) 2026-03-02 20:51:59 +00:00
Evan Lohn
30704f427f refactor: add abstraction for cache backend (#8870) 2026-03-02 20:50:13 +00:00
Jamison Lahman
4f3c54f282 chore(playwright): hide actions toolbar buttons in screenshots (#8914) 2026-03-02 20:10:21 +00:00
Jamison Lahman
580d41dc23 chore(mypy): run from repro root in CI (#6995) 2026-03-02 20:04:54 +00:00
Raunak Bhagat
897e181d67 refactor(opal): update ModalHeader to use Content (#8885) 2026-03-02 20:04:35 +00:00
dependabot[bot]
fd322a8a10 chore(deps): bump lxml-html-clean from 0.4.3 to 0.4.4 (#8919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-02 20:03:36 +00:00
Evan Lohn
11c54bafb5 chore: no vector db deployment (#8867) 2026-03-02 20:01:26 +00:00
Nikolas Garza
c93617df5d test(scim): add integration tests for SCIM user CRUD (#8825) 2026-03-02 19:38:33 +00:00
Justin Tahara
0cdd438f46 chore(ui): Update the Share Agent Modal (#8915) 2026-03-02 19:28:49 +00:00
Justin Tahara
31aef36f78 chore(llm): Use AWS Secrrets Manager (#8913) 2026-03-02 19:28:43 +00:00
Jamison Lahman
0c35dfc0e4 fix(search): re-sync search/chat preference on user data load (#8868) 2026-03-02 18:59:02 +00:00
Nikolas Garza
a9769757fe fix(llm): enforce persona restrictions on public LLM providers (#8846)
Co-authored-by: Dane <dane@onyx.app>
2026-03-02 18:20:03 +00:00
Nikolas Garza
15d8946f40 refactor(fe): rename assistant → agent identifiers (#8869) 2026-03-02 18:19:23 +00:00
Nikolas Garza
ba79539d6d feat(slack): add Slack user deactivation and seat-aware reactivation (#8887) 2026-03-02 18:10:13 +00:00
Jamison Lahman
59d3725fc6 chore(gha): rm docker-compose.opensearch.yml ref (#8912) 2026-03-02 17:34:22 +00:00
Jamison Lahman
9c05bd215d fix(a11y): settings popover buttons prefer href (#8880)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-02 17:22:01 +00:00
Wenxi
4d2aa09654 feat: infinite chat session sidebar scroll (#8874) 2026-03-02 17:20:23 +00:00
Jamison Lahman
16c07c8756 feat(desktop): option to hide alt-menu (#8882)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-02 17:03:24 +00:00
Raunak Bhagat
3fb4f5d6e6 refactor(opal): split ContentLg into ContentXl + ContentLg (#8904) 2026-03-02 15:11:32 +00:00
Evan Lohn
14fab7fcdf feat: no vector db beat tasks (#8865) 2026-03-02 03:51:18 +00:00
Evan Lohn
22a335fffa feat: bg tasks via fastapi (#8861) 2026-03-02 02:35:31 +00:00
Justin Tahara
b0f7466eba chore(llm): Fixing test image selection (#8902) 2026-03-01 18:32:20 +00:00
Evan Lohn
b1d42726b1 test: file reader tool (#8856) 2026-02-28 23:09:43 +00:00
Yuhong Sun
7d922bffc1 chore: Persona cleanup (#8810) 2026-02-28 21:34:45 +00:00
Evan Lohn
de7fc36fc5 test: no vector db user file processing (#8854) 2026-02-28 04:19:59 +00:00
Evan Lohn
7f9e37450d fix: non vector db tasks (#8849) 2026-02-28 03:51:57 +00:00
Evan Lohn
c7ef85b733 chore: narrow no_vector_db supported scope (#8847) 2026-02-28 02:54:15 +00:00
Danelegend
bd9319e592 feat: LLM Provider Rework (#8761)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-02-28 01:29:49 +00:00
Nikolas Garza
db5955d6f2 fix(ee): show Access Restricted page when seat limit exceeded (#8877) 2026-02-28 01:26:00 +00:00
420 changed files with 13110 additions and 8221 deletions

View File

@@ -54,6 +54,7 @@ runs:
shell: bash
env:
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
TAG: nightly-llm-it-${{ inputs.run-id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ inputs.github-sha }}

View File

@@ -426,8 +426,9 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
@@ -499,8 +500,9 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
@@ -646,8 +648,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
@@ -728,8 +730,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
@@ -862,8 +864,9 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
@@ -934,8 +937,9 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
@@ -1072,8 +1076,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
@@ -1145,8 +1149,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
@@ -1287,8 +1291,9 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
@@ -1366,8 +1371,9 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max

View File

@@ -15,6 +15,9 @@ permissions:
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
permissions:
contents: read
id-token: write
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
@@ -25,16 +28,6 @@ jobs:
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
azure_api_key: ${{ secrets.AZURE_API_KEY }}
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]

View File

@@ -114,8 +114,10 @@ jobs:
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:

View File

@@ -160,7 +160,7 @@ jobs:
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
# Collect logs from each container
for container in $containers; do

View File

@@ -8,7 +8,7 @@ on:
pull_request:
branches:
- main
- 'release/**'
- "release/**"
push:
tags:
- "v*.*.*"
@@ -21,7 +21,13 @@ jobs:
# See https://runs-on.com/runners/linux/
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
runs-on:
[
runs-on,
runner=2cpu-linux-x64,
"run-id=${{ github.run_id }}-mypy-check",
"extras=s3-cache",
]
timeout-minutes: 45
steps:
@@ -52,21 +58,14 @@ jobs:
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy
working-directory: ./backend
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy .
- name: Run MyPy (tools/)
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy tools/

View File

@@ -48,28 +48,10 @@ on:
required: false
default: true
type: boolean
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
azure_api_key:
required: false
ollama_api_key:
required: false
openrouter_api_key:
required: false
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
required: true
permissions:
contents: read
id-token: write
jobs:
build-backend-image:
@@ -81,6 +63,7 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -89,6 +72,19 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
@@ -97,8 +93,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
@@ -110,6 +106,7 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -118,6 +115,19 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
@@ -126,8 +136,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
build-integration-image:
runs-on:
@@ -138,6 +148,7 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -146,6 +157,19 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
@@ -154,8 +178,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
provider-chat-test:
needs:
@@ -170,56 +194,56 @@ jobs:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
custom_config_secret: ""
api_key_env: OPENAI_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
custom_config_secret: ""
api_key_env: ANTHROPIC_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_secret: bedrock_api_key
custom_config_secret: ""
api_key_env: BEDROCK_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
api_key_env: ""
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_secret: azure_api_key
custom_config_secret: ""
api_key_env: AZURE_API_KEY
custom_config_env: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_secret: ollama_api_key
custom_config_secret: ""
api_key_env: OLLAMA_API_KEY
custom_config_env: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_secret: openrouter_api_key
custom_config_secret: ""
api_key_env: OPENROUTER_API_KEY
custom_config_env: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
@@ -230,6 +254,7 @@ jobs:
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -238,21 +263,43 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
parse-json-secrets: false
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
OPENAI_API_KEY, test/openai-api-key
ANTHROPIC_API_KEY, test/anthropic-api-key
BEDROCK_API_KEY, test/bedrock-api-key
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
AZURE_API_KEY, test/azure-api-key
OLLAMA_API_KEY, test/ollama-api-key
OPENROUTER_API_KEY, test/openrouter-api-key
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()

View File

@@ -0,0 +1,37 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"cache_store",
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", sa.LargeBinary(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("key"),
)
op.create_index(
"ix_cache_store_expires",
"cache_store",
["expires_at"],
postgresql_where=sa.text("expires_at IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index("ix_cache_store_expires", table_name="cache_store")
op.drop_table("cache_store")

View File

@@ -0,0 +1,51 @@
"""Add INDEXING to UserFileStatus
Revision ID: 4a1e4b1c89d2
Revises: 6b3b4083c5aa
Create Date: 2026-02-28 00:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
revision = "4a1e4b1c89d2"
down_revision = "6b3b4083c5aa"
branch_labels = None
depends_on = None
TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
def _drop_status_check_constraint() -> None:
"""Drop the existing CHECK constraint on user_file.status.
The constraint name is auto-generated by SQLAlchemy and unknown,
so we look it up via the inspector.
"""
inspector = sa.inspect(op.get_bind())
for constraint in inspector.get_check_constraints(TABLE):
if COLUMN in constraint.get("sqltext", ""):
constraint_name = constraint["name"]
if constraint_name is not None:
op.drop_constraint(constraint_name, TABLE, type_="check")
def upgrade() -> None:
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
def downgrade() -> None:
op.execute(
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
)
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

View File

@@ -0,0 +1,112 @@
"""persona cleanup and featured
Revision ID: 6b3b4083c5aa
Revises: 57122d037335
Create Date: 2026-02-26 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6b3b4083c5aa"
down_revision = "57122d037335"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add featured column with nullable=True first
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
# Migrate data from is_default_persona to featured
op.execute("UPDATE persona SET featured = is_default_persona")
# Make featured non-nullable with default=False
op.alter_column(
"persona",
"featured",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
# Drop is_default_persona column
op.drop_column("persona", "is_default_persona")
# Drop unused columns
op.drop_column("persona", "num_chunks")
op.drop_column("persona", "chunks_above")
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "llm_relevance_filter")
op.drop_column("persona", "llm_filter_extraction")
op.drop_column("persona", "recency_bias")
def downgrade() -> None:
# Add back recency_bias column
op.add_column(
"persona",
sa.Column(
"recency_bias",
sa.VARCHAR(),
nullable=False,
server_default="base_decay",
),
)
# Add back llm_filter_extraction column
op.add_column(
"persona",
sa.Column(
"llm_filter_extraction",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Add back llm_relevance_filter column
op.add_column(
"persona",
sa.Column(
"llm_relevance_filter",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Add back chunks_below column
op.add_column(
"persona",
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
)
# Add back chunks_above column
op.add_column(
"persona",
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
)
# Add back num_chunks column
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
# Add back is_default_persona column
op.add_column(
"persona",
sa.Column(
"is_default_persona",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Migrate data from featured to is_default_persona
op.execute("UPDATE persona SET is_default_persona = featured")
# Drop featured column
op.drop_column("persona", "featured")

View File

@@ -0,0 +1,34 @@
"""make scim_user_mapping.external_id nullable
Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"scim_user_mapping",
"external_id",
nullable=True,
)
def downgrade() -> None:
# Delete any rows where external_id is NULL before re-applying NOT NULL
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
op.alter_column(
"scim_user_mapping",
"external_id",
nullable=False,
)

View File

@@ -126,12 +126,16 @@ class ScimDAL(DAL):
def create_user_mapping(
self,
external_id: str,
external_id: str | None,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
"""Create a SCIM mapping for a user.
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
allows this). The mapping still marks the user as SCIM-managed.
"""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
@@ -270,8 +274,13 @@ class ScimDAL(DAL):
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
# unless they were explicitly linked via SCIM provisioning.
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
)
if scim_filter:
@@ -321,34 +330,37 @@ class ScimDAL(DAL):
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user.
"""Sync the SCIM mapping for a user.
If a mapping already exists, its fields are updated (including
setting ``external_id`` to ``None`` when the IdP omits it).
If no mapping exists and ``new_external_id`` is provided, a new
mapping is created. A mapping is never deleted here — SCIM-managed
users must retain their mapping to remain visible in ``GET /Users``.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
elif new_external_id:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
def _get_user_mappings_batch(
self, user_ids: list[UUID]

View File

@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import register_scim_exception_handlers
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
@@ -167,6 +168,7 @@ def get_application() -> FastAPI:
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
register_scim_exception_handlers(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -15,7 +15,9 @@ from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import FastAPI
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
@@ -24,6 +26,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import ScimAuthError
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
@@ -77,6 +80,22 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def register_scim_exception_handlers(app: FastAPI) -> None:
"""Register SCIM-specific exception handlers on the FastAPI app.
Call this after ``app.include_router(scim_router)`` so that auth
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
envelopes (with ``schemas`` and ``status`` fields) instead of
FastAPI's default ``{"detail": "..."}`` format.
"""
@app.exception_handler(ScimAuthError)
async def _handle_scim_auth_error(
_request: Request, exc: ScimAuthError
) -> ScimJSONResponse:
return _scim_error_response(exc.status_code, exc.detail)
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
@@ -404,21 +423,63 @@ def create_user(
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Check for existing user — if they exist but aren't SCIM-managed yet,
# link them to the IdP rather than rejecting with 409.
external_id: str | None = user_resource.externalId
scim_username: str = user_resource.userName.strip()
fields: ScimMappingFields = _fields_from_resource(user_resource)
# Enforce seat limit
existing_user = dal.get_user_by_email(email)
if existing_user:
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
if existing_mapping:
return _scim_error_response(409, f"User with email {email} already exists")
# Adopt pre-existing user into SCIM management.
# Reactivating a deactivated user consumes a seat, so enforce the
# seat limit the same way replace_user does.
if user_resource.active and not existing_user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
existing_user,
is_active=user_resource.active,
**({"personal_name": personal_name} if personal_name else {}),
)
try:
dal.create_user_mapping(
external_id=external_id,
user_id=existing_user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(
existing_user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
# Only enforce seat limit for net-new users — adopting a pre-existing
# user doesn't consume a new seat.
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
@@ -436,18 +497,21 @@ def create_user(
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
# Always create a SCIM mapping so that the user is marked as
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
try:
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(

View File

@@ -19,7 +19,6 @@ import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
class ScimAuthError(Exception):
"""Raised when SCIM bearer token authentication fails.
Unlike HTTPException, this carries the status and detail so the SCIM
exception handler can wrap them in an RFC 7644 §3.12 error envelope
with ``schemas`` and ``status`` fields.
"""
def __init__(self, status_code: int, detail: str) -> None:
self.status_code = status_code
self.detail = detail
super().__init__(detail)
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
@@ -82,23 +96,14 @@ def verify_scim_token(
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
token = dal.get_token_by_hash(hashed)
if not token:
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
raise ScimAuthError(401, "Invalid SCIM bearer token")
if not token.is_active:
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
raise ScimAuthError(401, "SCIM token has been revoked")
return token

View File

@@ -153,26 +153,31 @@ class ScimProvider(ABC):
self,
user: User,
fields: ScimMappingFields,
) -> ScimName | None:
) -> ScimName:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Always returns a ScimName — Okta's spec tests expect ``name``
(with ``givenName``/``familyName``) on every user resource.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
givenName=fields.given_name or "",
familyName=fields.family_name or "",
formatted=user.personal_name or "",
)
if not user.personal_name:
return None
# Derive a reasonable name from the email so that SCIM spec tests
# see non-empty givenName / familyName for every user resource.
local = user.email.split("@")[0] if user.email else ""
return ScimName(givenName=local, familyName="", formatted=local)
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
familyName=parts[1] if len(parts) > 1 else "",
formatted=user.personal_name,
)

View File

@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
store_settings as store_ee_settings,
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -117,15 +117,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -137,12 +160,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
@@ -154,6 +171,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=persona.datetime_aware,
featured=persona.featured,
commit=False,
)
db_session.commit()

View File

@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True

View File

@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, default_model, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
result = await db_session.execute(
select(Persona.id)
.where(
Persona.is_default_persona.is_(True),
Persona.featured.is_(True),
Persona.is_public.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),
@@ -725,11 +725,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if user_by_session:
user = user_by_session
# If the user is inactive, check seat availability before
# upgrading role — otherwise they'd become an inactive BASIC
# user who still can't log in.
if not user.is_active:
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)

View File

@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
"check-for-index-attempt-cleanup",
"check-for-doc-permissions-sync",
"check-for-external-group-sync",
"check-for-documents-for-opensearch-migration",
"migrate-documents-from-vespa-to-opensearch",
"migrate-chunks-from-vespa-to-opensearch",
}
if DISABLE_VECTOR_DB:

View File

@@ -414,34 +414,31 @@ def _process_user_file_with_indexing(
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def process_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
"""Core implementation for processing a single user file.
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips all Redis
operations (BackgroundTask path).
"""
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
documents: list[Document] = []
try:
@@ -449,15 +446,18 @@ def process_single_user_file(
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if not uf:
task_logger.warning(
f"process_single_user_file - UserFile not found id={user_file_id}"
f"process_user_file_impl - UserFile not found id={user_file_id}"
)
return None
return
if uf.status != UserFileStatus.PROCESSING:
if uf.status not in (
UserFileStatus.PROCESSING,
UserFileStatus.INDEXING,
):
task_logger.info(
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
)
return None
return
connector = LocalFileConnector(
file_locations=[uf.file_id],
@@ -471,7 +471,6 @@ def process_single_user_file(
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
)
# update the document id to userfile id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
@@ -493,9 +492,8 @@ def process_single_user_file(
except Exception as e:
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
# don't update the status if the user file is being deleted
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
@@ -504,33 +502,43 @@ def process_single_user_file(
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
return
elapsed = time.monotonic() - start
task_logger.info(
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
)
return None
except Exception as e:
# Attempt to mark the file as failed
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
# don't update the status if the user file is being deleted
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
raise
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
process_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
soft_time_limit=300,
@@ -581,36 +589,38 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
return None
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def delete_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
"""Process a single user file delete."""
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
"""Core implementation for deleting a single user file.
When redis_locking=True, acquires a per-file Redis lock (Celery path).
When redis_locking=False, skips Redis operations (BackgroundTask path).
"""
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_delete - User file not found id={user_file_id}"
f"delete_user_file_impl - User file not found id={user_file_id}"
)
return None
return
# 1) Delete vector DB chunks (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -648,7 +658,6 @@ def process_single_user_file_delete(
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
try:
file_store.delete_file(user_file.file_id)
@@ -656,26 +665,34 @@ def process_single_user_file_delete(
user_file_id_to_plaintext_file_name(user_file.id)
)
except Exception as e:
# This block executed only if the file is not found in the filestore
task_logger.exception(
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
)
# 3) Finally, delete the UserFile row
db_session.delete(user_file)
db_session.commit()
task_logger.info(
f"process_single_user_file_delete - Completed id={user_file_id}"
)
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
except Exception as e:
task_logger.exception(
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
raise
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
return None
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
delete_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
@shared_task(
@@ -747,32 +764,30 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def project_sync_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
"""Process a single user file project sync."""
task_logger.info(
f"process_single_user_file_project_sync - Starting id={user_file_id}"
)
"""Core implementation for syncing a user file's project/persona metadata.
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips Redis
operations (BackgroundTask path).
"""
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
file_lock: RedisLock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
try:
with get_session_with_current_tenant() as db_session:
@@ -783,11 +798,10 @@ def process_single_user_file_project_sync(
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
f"project_sync_user_file_impl - User file not found id={user_file_id}"
)
return None
return
# Sync project metadata to vector DB (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -822,7 +836,7 @@ def process_single_user_file_project_sync(
)
task_logger.info(
f"process_single_user_file_project_sync - User file id={user_file_id}"
f"project_sync_user_file_impl - User file id={user_file_id}"
)
user_file.needs_project_sync = False
@@ -835,11 +849,22 @@ def process_single_user_file_project_sync(
except Exception as e:
task_logger.exception(
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
return None
raise
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
project_sync_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)

View File

@@ -0,0 +1,307 @@
"""Periodic poller for NO_VECTOR_DB deployments.
Replaces Celery Beat and background workers with a lightweight daemon thread
that runs from the API server process. Two responsibilities:
1. Recovery polling (every 30 s): re-processes user files stuck in
PROCESSING / DELETING / needs_sync states via the drain loops defined
in ``task_utils.py``.
2. Periodic task execution (configurable intervals): runs LLM model updates
and scheduled evals at their configured cadences, with Postgres advisory
lock deduplication across multiple API server instances.
"""
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field
from onyx.utils.logger import setup_logger
logger = setup_logger()
RECOVERY_INTERVAL_SECONDS = 30
PERIODIC_TASK_LOCK_BASE = 20_000
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
# ------------------------------------------------------------------
# Periodic task definitions
# ------------------------------------------------------------------
_NEVER_RAN: float = -1e18
@dataclass
class _PeriodicTaskDef:
name: str
interval_seconds: float
lock_id: int
run_fn: Callable[[], None]
last_run_at: float = field(default=_NEVER_RAN)
def _run_auto_llm_update() -> None:
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
if not AUTO_LLM_CONFIG_URL:
return
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.well_known_providers.auto_update_service import (
sync_llm_models_from_github,
)
with get_session_with_current_tenant() as db_session:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
if not all(
[
BRAINTRUST_API_KEY,
SCHEDULED_EVAL_PROJECT,
SCHEDULED_EVAL_DATASET_NAMES,
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
]
):
return
from datetime import datetime
from datetime import timezone
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
try:
run_eval(
configuration=EvalConfigurationOptions(
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
dataset_name=dataset_name,
no_send_logs=False,
braintrust_project=SCHEDULED_EVAL_PROJECT,
experiment_name=f"{dataset_name} - {run_timestamp}",
),
remote_dataset_name=dataset_name,
)
except Exception:
logger.exception(
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
)
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(
name="auto-llm-update",
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE,
run_fn=_run_auto_llm_update,
)
)
if SCHEDULED_EVAL_DATASET_NAMES:
tasks.append(
_PeriodicTaskDef(
name="scheduled-eval",
interval_seconds=7 * 24 * 3600,
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
run_fn=_run_scheduled_eval,
)
)
return tasks
# ------------------------------------------------------------------
# Periodic task runner with advisory-lock-guarded claim
# ------------------------------------------------------------------
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
"""Atomically check whether *task_def* should run and record a claim.
Uses a transaction-scoped advisory lock for atomicity combined with a
``KVStore`` timestamp for cross-instance dedup. The DB session is held
only for this brief claim transaction, not during task execution.
"""
from datetime import datetime
from datetime import timezone
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
with get_session_with_current_tenant() as db_session:
acquired = db_session.execute(
text("SELECT pg_try_advisory_xact_lock(:id)"),
{"id": task_def.lock_id},
).scalar()
if not acquired:
return False
row = db_session.query(KVStore).filter_by(key=kv_key).first()
if row and row.value is not None:
last_claimed = datetime.fromisoformat(str(row.value))
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
if elapsed < task_def.interval_seconds:
return False
now_ts = datetime.now(timezone.utc).isoformat()
if row:
row.value = now_ts
else:
db_session.add(KVStore(key=kv_key, value=now_ts))
db_session.commit()
return True
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
now = time.monotonic()
if now - task_def.last_run_at < task_def.interval_seconds:
return
if not _try_claim_task(task_def):
return
try:
task_def.run_fn()
task_def.last_run_at = now
except Exception:
logger.exception(
f"Periodic poller - Error running periodic task {task_def.name}"
)
# ------------------------------------------------------------------
# Recovery / drain loop runner
# ------------------------------------------------------------------
def _run_drain_loops(tenant_id: str) -> None:
from onyx.background.task_utils import drain_delete_loop
from onyx.background.task_utils import drain_processing_loop
from onyx.background.task_utils import drain_project_sync_loop
drain_processing_loop(tenant_id)
drain_delete_loop(tenant_id)
drain_project_sync_loop(tenant_id)
# ------------------------------------------------------------------
# Startup recovery (10g)
# ------------------------------------------------------------------
def recover_stuck_user_files(tenant_id: str) -> None:
"""Run all drain loops once to re-process files left in intermediate states.
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
"""
logger.info("recover_stuck_user_files - Checking for stuck user files")
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("recover_stuck_user_files - Error during recovery")
# ------------------------------------------------------------------
# Daemon thread (10f)
# ------------------------------------------------------------------
_shutdown_event = threading.Event()
_poller_thread: threading.Thread | None = None
def _poller_loop(tenant_id: str) -> None:
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
periodic_tasks = _build_periodic_tasks()
logger.info(
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
f"{[t.name for t in periodic_tasks]}"
)
while not _shutdown_event.is_set():
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("Periodic poller - Error in recovery polling")
for task_def in periodic_tasks:
try:
_try_run_periodic_task(task_def)
except Exception:
logger.exception(
f"Periodic poller - Unhandled error checking task {task_def.name}"
)
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
def start_periodic_poller(tenant_id: str) -> None:
"""Start the periodic poller daemon thread."""
global _poller_thread # noqa: PLW0603
_shutdown_event.clear()
_poller_thread = threading.Thread(
target=_poller_loop,
args=(tenant_id,),
daemon=True,
name="no-vectordb-periodic-poller",
)
_poller_thread.start()
logger.info("Periodic poller thread started")
def stop_periodic_poller() -> None:
"""Signal the periodic poller to stop and wait for it to exit."""
global _poller_thread # noqa: PLW0603
if _poller_thread is None:
return
_shutdown_event.set()
_poller_thread.join(timeout=10)
if _poller_thread.is_alive():
logger.warning("Periodic poller thread did not stop within timeout")
_poller_thread = None
logger.info("Periodic poller thread stopped")

View File

@@ -1,3 +1,33 @@
"""Background task utilities.
Contains query-history report helpers (used by all deployment modes) and
in-process background task execution helpers for NO_VECTOR_DB mode:
- Atomic claim-and-mark helpers that prevent duplicate processing
- Drain loops that process all pending user file work
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
After the commit the row lock is released, but the row is no longer
eligible for re-claiming. No long-lived sessions or advisory locks.
"""
from uuid import UUID
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.utils.logger import setup_logger
logger = setup_logger()
# ------------------------------------------------------------------
# Query-history report helpers (pre-existing, used by all modes)
# ------------------------------------------------------------------
QUERY_REPORT_NAME_PREFIX = "query-history"
@@ -9,3 +39,168 @@ def construct_query_history_report_name(
def extract_task_id_from_query_history_report_name(name: str) -> str:
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
# ------------------------------------------------------------------
# Atomic claim-and-mark helpers
# ------------------------------------------------------------------
# Each function runs inside a single short-lived session/transaction:
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
# 2. UPDATE the row so it is no longer eligible
# 3. COMMIT (releases the row lock)
# After the commit, no other drain loop can claim the same row.
def _claim_next_processing_file(db_session: Session) -> UUID | None:
"""Claim the next PROCESSING file by transitioning it to INDEXING.
Returns the file id, or None when no eligible files remain.
"""
file_id = db_session.execute(
select(UserFile.id)
.where(UserFile.status == UserFileStatus.PROCESSING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
if file_id is None:
return None
db_session.execute(
sa.update(UserFile)
.where(UserFile.id == file_id)
.values(status=UserFileStatus.INDEXING)
)
db_session.commit()
return file_id
def _claim_next_deleting_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
"""Claim the next DELETING file.
No status transition needed — the impl deletes the row on success.
The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
select(UserFile.id)
.where(UserFile.status == UserFileStatus.DELETING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
db_session.commit()
return file_id
def _claim_next_sync_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
"""Claim the next file needing project/persona sync.
No status transition needed — the impl clears the sync flags on
success. The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
select(UserFile.id)
.where(
sa.and_(
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.status == UserFileStatus.COMPLETED,
)
)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
db_session.commit()
return file_id
# ------------------------------------------------------------------
# Drain loops — process *all* pending work of each type
# ------------------------------------------------------------------
def drain_processing_loop(tenant_id: str) -> None:
"""Process all pending PROCESSING user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_processing_file(session)
if file_id is None:
break
try:
process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to process user file {file_id}")
def drain_delete_loop(tenant_id: str) -> None:
"""Delete all pending DELETING user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
delete_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
if file_id is None:
break
try:
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to delete user file {file_id}")
failed.add(file_id)
def drain_project_sync_loop(tenant_id: str) -> None:
"""Sync all pending project/persona metadata for user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
project_sync_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_sync_file(session, exclude_ids=failed)
if file_id is None:
break
try:
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to sync user file {file_id}")
failed.add(file_id)

51
backend/onyx/cache/factory.py vendored Normal file
View File

@@ -0,0 +1,51 @@
from collections.abc import Callable
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import CACHE_BACKEND
def _build_redis_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.redis.redis_pool import redis_pool
return RedisCacheBackend(redis_pool.get_client(tenant_id))
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.postgres_backend import PostgresCacheBackend
return PostgresCacheBackend(tenant_id)
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
CacheBackendType.POSTGRES: _build_postgres_backend,
}
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
"""Return a tenant-aware ``CacheBackend``.
If *tenant_id* is ``None``, the current tenant is read from the
thread-local context variable (same behaviour as ``get_redis_client``).
"""
if tenant_id is None:
from shared_configs.contextvars import get_current_tenant_id
tenant_id = get_current_tenant_id()
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
if builder is None:
raise ValueError(
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
f"Supported values: {[t.value for t in CacheBackendType]}"
)
return builder(tenant_id)
def get_shared_cache_backend() -> CacheBackend:
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
from shared_configs.configs import DEFAULT_REDIS_PREFIX
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)

104
backend/onyx/cache/interface.py vendored Normal file
View File

@@ -0,0 +1,104 @@
import abc
from enum import Enum
TTL_KEY_NOT_FOUND = -2
TTL_NO_EXPIRY = -1
class CacheBackendType(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"
class CacheLock(abc.ABC):
"""Abstract distributed lock returned by CacheBackend.lock()."""
@abc.abstractmethod
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
raise NotImplementedError
@abc.abstractmethod
def release(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def owned(self) -> bool:
raise NotImplementedError
def __enter__(self) -> "CacheLock":
if not self.acquire():
raise RuntimeError("Failed to acquire lock")
return self
def __exit__(self, *args: object) -> None:
self.release()
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
Covers the subset of Redis operations used outside of Celery. When
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
"""
# -- basic key/value ---------------------------------------------------
@abc.abstractmethod
def get(self, key: str) -> bytes | None:
raise NotImplementedError
@abc.abstractmethod
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def delete(self, key: str) -> None:
raise NotImplementedError
@abc.abstractmethod
def exists(self, key: str) -> bool:
raise NotImplementedError
# -- TTL ---------------------------------------------------------------
@abc.abstractmethod
def expire(self, key: str, seconds: int) -> None:
raise NotImplementedError
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds.
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
"""
raise NotImplementedError
# -- distributed lock --------------------------------------------------
@abc.abstractmethod
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
raise NotImplementedError
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
@abc.abstractmethod
def rpush(self, key: str, value: str | bytes) -> None:
raise NotImplementedError
@abc.abstractmethod
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
"""Block until a value is available on one of *keys*, or *timeout* expires.
Returns ``(key, value)`` or ``None`` on timeout.
"""
raise NotImplementedError

323
backend/onyx/cache/postgres_backend.py vendored Normal file
View File

@@ -0,0 +1,323 @@
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
for distributed locking, and a polling loop for the BLPOP pattern.
"""
import hashlib
import struct
import time
import uuid
from contextlib import AbstractContextManager
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.db.models import CacheStore
_LIST_KEY_PREFIX = "_q:"
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
# lists whose names share a prefix (e.g. _q:mylist2:...).
_LIST_KEY_RANGE_TERMINATOR = ";"
_LIST_ITEM_TTL_SECONDS = 3600
_LOCK_POLL_INTERVAL = 0.1
_BLPOP_POLL_INTERVAL = 0.25
def _list_item_key(key: str) -> str:
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
collision when concurrent rpush calls occur within the same nanosecond.
"""
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
def _to_bytes(value: str | bytes | int | float) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()
# ------------------------------------------------------------------
# Lock
# ------------------------------------------------------------------
class PostgresCacheLock(CacheLock):
"""Advisory-lock-based distributed lock.
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
to the session's connection; releasing or closing the session frees it.
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
``timeout`` seconds. They are released when ``release()`` is
called or when the session is closed.
"""
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
self._lock_id = lock_id
self._timeout = timeout
self._tenant_id = tenant_id
self._session_cm: AbstractContextManager[Session] | None = None
self._session: Session | None = None
self._acquired = False
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
from onyx.db.engine.sql_engine import get_session_with_tenant
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
self._session = self._session_cm.__enter__()
try:
if not blocking:
return self._try_lock()
effective_timeout = blocking_timeout or self._timeout
deadline = (
(time.monotonic() + effective_timeout) if effective_timeout else None
)
while True:
if self._try_lock():
return True
if deadline is not None and time.monotonic() >= deadline:
return False
time.sleep(_LOCK_POLL_INTERVAL)
finally:
if not self._acquired:
self._close_session()
def release(self) -> None:
if not self._acquired or self._session is None:
return
try:
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
finally:
self._acquired = False
self._close_session()
def owned(self) -> bool:
return self._acquired
def _close_session(self) -> None:
if self._session_cm is not None:
try:
self._session_cm.__exit__(None, None, None)
finally:
self._session_cm = None
self._session = None
def _try_lock(self) -> bool:
assert self._session is not None
result = self._session.execute(
select(func.pg_try_advisory_lock(self._lock_id))
).scalar()
if result:
self._acquired = True
return True
return False
# ------------------------------------------------------------------
# Backend
# ------------------------------------------------------------------
class PostgresCacheBackend(CacheBackend):
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
Each operation opens and closes its own database session so the backend
is safe to share across threads. Tenant isolation is handled by
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
"""
def __init__(self, tenant_id: str) -> None:
self._tenant_id = tenant_id
# -- basic key/value ---------------------------------------------------
def get(self, key: str) -> bytes | None:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = select(CacheStore.value).where(
CacheStore.key == key,
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
value = session.execute(stmt).scalar_one_or_none()
if value is None:
return None
return bytes(value)
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
value_bytes = _to_bytes(value)
expires_at = (
datetime.now(timezone.utc) + timedelta(seconds=ex)
if ex is not None
else None
)
stmt = (
pg_insert(CacheStore)
.values(key=key, value=value_bytes, expires_at=expires_at)
.on_conflict_do_update(
index_elements=[CacheStore.key],
set_={"value": value_bytes, "expires_at": expires_at},
)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(stmt)
session.commit()
def delete(self, key: str) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(delete(CacheStore).where(CacheStore.key == key))
session.commit()
def exists(self, key: str) -> bool:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = (
select(CacheStore.key)
.where(
CacheStore.key == key,
or_(
CacheStore.expires_at.is_(None),
CacheStore.expires_at > func.now(),
),
)
.limit(1)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
return session.execute(stmt).first() is not None
# -- TTL ---------------------------------------------------------------
def expire(self, key: str, seconds: int) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
stmt = (
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(stmt)
session.commit()
def ttl(self, key: str) -> int:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
result = session.execute(stmt).first()
if result is None:
return TTL_KEY_NOT_FOUND
expires_at: datetime | None = result[0]
if expires_at is None:
return TTL_NO_EXPIRY
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
if remaining <= 0:
return TTL_KEY_NOT_FOUND
return int(remaining)
# -- distributed lock --------------------------------------------------
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
return PostgresCacheLock(
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
)
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
def rpush(self, key: str, value: str | bytes) -> None:
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
if timeout <= 0:
raise ValueError(
"PostgresCacheBackend.blpop requires timeout > 0. "
"timeout=0 would block the calling thread indefinitely "
"with no way to interrupt short of process termination."
)
from onyx.db.engine.sql_engine import get_session_with_tenant
deadline = time.monotonic() + timeout
while True:
for key in keys:
lower = f"{_LIST_KEY_PREFIX}{key}:"
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
stmt = (
select(CacheStore)
.where(
CacheStore.key >= lower,
CacheStore.key < upper,
or_(
CacheStore.expires_at.is_(None),
CacheStore.expires_at > func.now(),
),
)
.order_by(CacheStore.key)
.limit(1)
.with_for_update(skip_locked=True)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
row = session.execute(stmt).scalars().first()
if row is not None:
value = bytes(row.value) if row.value else b""
session.delete(row)
session.commit()
return (key.encode(), value)
if time.monotonic() >= deadline:
return None
time.sleep(_BLPOP_POLL_INTERVAL)
# -- helpers -----------------------------------------------------------
def _lock_id_for(self, name: str) -> int:
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
return struct.unpack("q", h[:8])[0]
# ------------------------------------------------------------------
# Periodic cleanup
# ------------------------------------------------------------------
def cleanup_expired_cache_entries() -> None:
"""Delete rows whose ``expires_at`` is in the past.
Called by the periodic poller every 5 minutes.
"""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as session:
session.execute(
delete(CacheStore).where(
CacheStore.expires_at.is_not(None),
CacheStore.expires_at < func.now(),
)
)
session.commit()

92
backend/onyx/cache/redis_backend.py vendored Normal file
View File

@@ -0,0 +1,92 @@
from typing import cast
from redis.client import Redis
from redis.lock import Lock as RedisLock
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
class RedisCacheLock(CacheLock):
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
def __init__(self, lock: RedisLock) -> None:
self._lock = lock
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
return bool(
self._lock.acquire(
blocking=blocking,
blocking_timeout=blocking_timeout,
)
)
def release(self) -> None:
self._lock.release()
def owned(self) -> bool:
return bool(self._lock.owned())
class RedisCacheBackend(CacheBackend):
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
This is a thin pass-through — every method maps 1-to-1 to the underlying
Redis command. ``TenantRedis`` key-prefixing is handled by the client
itself (provided by ``get_redis_client``).
"""
def __init__(self, redis_client: Redis) -> None:
self._r = redis_client
# -- basic key/value ---------------------------------------------------
def get(self, key: str) -> bytes | None:
val = self._r.get(key)
if val is None:
return None
if isinstance(val, bytes):
return val
return str(val).encode()
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
self._r.set(key, value, ex=ex)
def delete(self, key: str) -> None:
self._r.delete(key)
def exists(self, key: str) -> bool:
return bool(self._r.exists(key))
# -- TTL ---------------------------------------------------------------
def expire(self, key: str, seconds: int) -> None:
self._r.expire(key, seconds)
def ttl(self, key: str) -> int:
return cast(int, self._r.ttl(key))
# -- distributed lock --------------------------------------------------
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
return RedisCacheLock(self._r.lock(name, timeout=timeout))
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
def rpush(self, key: str, value: str | bytes) -> None:
self._r.rpush(key, value)
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
if result is None:
return None
return (result[0], result[1])

View File

@@ -1,57 +1,52 @@
from uuid import UUID
from redis.client import Redis
from onyx.cache.interface import CacheBackend
# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60 # 30 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""
Generate the Redis key for a chat session processing a message.
"""Generate the cache key for a chat session processing fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string (tenant_id is automatically added by the Redis client)
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_processing_status(
chat_session_id: UUID, redis_client: Redis, value: bool
chat_session_id: UUID, cache: CacheBackend, value: bool
) -> None:
"""
Set or clear the fence for a chat session processing a message.
"""Set or clear the fence for a chat session processing a message.
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
If the key exists, a message is being processed.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
cache: Tenant-aware cache backend
value: True to set the fence, False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if value:
redis_client.set(fence_key, 0, ex=FENCE_TTL)
cache.set(fence_key, 0, ex=FENCE_TTL)
else:
redis_client.delete(fence_key)
cache.delete(fence_key)
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session is processing a message.
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session is processing a message.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
cache: Tenant-aware cache backend
Returns:
True if the chat session is processing a message, False otherwise
"""
fence_key = _get_fence_key(chat_session_id)
return bool(redis_client.exists(fence_key))
return cache.exists(_get_fence_key(chat_session_id))

View File

@@ -1,6 +1,7 @@
import json
import time
from collections.abc import Callable
from typing import Any
from typing import Literal
from sqlalchemy.orm import Session
@@ -530,11 +531,13 @@ def _create_file_tool_metadata_message(
"""
lines = [
"You have access to the following files. Use the read_file tool to "
"read sections of any file:"
"read sections of any file. You MUST pass the file_id UUID (not the "
"filename) to read_file:"
]
for meta in file_metadata:
lines.append(
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
f"(~{meta.approx_char_count:,} chars)"
)
message_content = "\n".join(lines)
@@ -558,12 +561,16 @@ def _create_context_files_message(
# Format as documents JSON as described in README
documents_list = []
for idx, file_text in enumerate(context_files.file_texts, start=1):
documents_list.append(
{
"document": idx,
"contents": file_text,
}
title = (
context_files.file_metadata[idx - 1].filename
if idx - 1 < len(context_files.file_metadata)
else None
)
entry: dict[str, Any] = {"document": idx}
if title:
entry["title"] = title
entry["contents"] = file_text
documents_list.append(entry)
documents_json = json.dumps({"documents": documents_list}, indent=2)
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"

View File

@@ -11,9 +11,10 @@ from contextvars import Token
from uuid import UUID
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.cache.interface import CacheBackend
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
@@ -79,7 +80,6 @@ from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
llm: LLM | None = None
chat_session: ChatSession | None = None
redis_client: Redis | None = None
cache: CacheBackend | None = None
user_id = user.id
if user.is_anonymous:
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
)
simple_chat_history.insert(0, summary_simple)
redis_client = get_redis_client()
cache = get_cache_backend()
reset_cancel_status(
chat_session.id,
redis_client,
cache,
)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, redis_client)
return check_stop_signal(chat_session.id, cache)
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
cache=cache,
value=True,
)
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
reset_llm_mock_response(mock_response_token)
try:
if redis_client is not None and chat_session is not None:
if cache is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
cache=cache,
value=False,
)
except Exception:

View File

@@ -1,65 +1,58 @@
from uuid import UUID
from redis.client import Redis
from onyx.cache.interface import CacheBackend
# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
FENCE_TTL = 10 * 60 # 10 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""
Generate the Redis key for a chat session stop signal fence.
"""Generate the cache key for a chat session stop signal fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string (tenant_id is automatically added by the Redis client)
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
"""
Set or clear the stop signal fence for a chat session.
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
"""Set or clear the stop signal fence for a chat session.
Args:
chat_session_id: The UUID of the chat session
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
value: True to set the fence (stop signal), False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if not value:
redis_client.delete(fence_key)
cache.delete(fence_key)
return
redis_client.set(fence_key, 0, ex=FENCE_TTL)
cache.set(fence_key, 0, ex=FENCE_TTL)
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session should continue (not stopped).
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session should continue (not stopped).
Args:
chat_session_id: The UUID of the chat session to check
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
Returns:
True if the session should continue, False if it should stop
"""
fence_key = _get_fence_key(chat_session_id)
return not bool(redis_client.exists(fence_key))
return not cache.exists(_get_fence_key(chat_session_id))
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
"""
Clear the stop signal for a chat session.
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
"""Clear the stop signal for a chat session.
Args:
chat_session_id: The UUID of the chat session
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
"""
fence_key = _get_fence_key(chat_session_id)
redis_client.delete(fence_key)
cache.delete(_get_fence_key(chat_session_id))

View File

@@ -6,6 +6,7 @@ from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.constants import AuthType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -54,6 +55,12 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
# are disabled but core chat, tools, user file uploads, and Projects still work.
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
# Which backend to use for caching, locks, and ephemeral state.
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
@@ -812,7 +819,9 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
# Tool Configs
#####
# Code Interpreter Service Configuration
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
CODE_INTERPRETER_BASE_URL = os.environ.get(
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
)
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000

View File

@@ -98,6 +98,7 @@ def get_chat_sessions_by_user(
db_session: Session,
include_onyxbot_flows: bool = False,
limit: int = 50,
before: datetime | None = None,
project_id: int | None = None,
only_non_project_chats: bool = False,
include_failed_chats: bool = False,
@@ -112,6 +113,9 @@ def get_chat_sessions_by_user(
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)

View File

@@ -186,6 +186,7 @@ class EmbeddingPrecision(str, PyEnum):
class UserFileStatus(str, PyEnum):
PROCESSING = "PROCESSING"
INDEXING = "INDEXING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"

View File

@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
api_key=api_key,
api_base=None,
api_version=None,
default_model_name=model_name,
deployment_name=None,
is_public=True,
)

View File

@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
is_admin: If True, bypass user group restrictions but still respect persona restrictions
Access logic:
1. If is_public=True → everyone has access (public override)
2. If is_public=False:
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
- Only personas set → must use one of the personas (OR across personas, applies to admins)
- Neither set → NOBODY has access unless admin (locked, admin-only)
- is_public controls USER access (group bypass): when True, all users can access
regardless of group membership. When False, user must be in a whitelisted group
(or be admin).
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
This allows admins to make a provider available to all users while still
restricting which personas (assistants) can use it.
Decision matrix:
1. is_public=True, no personas set → everyone has access
2. is_public=True, personas set → all users, but only whitelisted personas
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
4. is_public=False, only groups set → must be in group (admins bypass)
5. is_public=False, only personas set → must use whitelisted persona
6. is_public=False, neither set → admin-only (locked)
"""
# Public override - everyone has access
if provider.is_public:
return True
# Extract IDs once to avoid multiple iterations
provider_group_ids = (
{group.id for group in provider.groups} if provider.groups else set()
)
provider_persona_ids = (
{p.id for p in provider.personas} if provider.personas else set()
)
provider_group_ids = {g.id for g in (provider.groups or [])}
provider_persona_ids = {p.id for p in (provider.personas or [])}
has_groups = bool(provider_group_ids)
has_personas = bool(provider_persona_ids)
# Both groups AND personas set → AND logic (must satisfy both)
if has_groups and has_personas:
# Admins bypass group check but still must satisfy persona restrictions
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
persona_allowed = persona.id in provider_persona_ids if persona else False
return user_in_group and persona_allowed
# Persona restrictions are always enforced when set, regardless of is_public
if has_personas and not (persona and persona.id in provider_persona_ids):
return False
if provider.is_public:
return True
# Only groups set → user must be in one of the groups (admins bypass)
if has_groups:
return is_admin or bool(user_group_ids & provider_group_ids)
# Only personas set → persona must be in allowed list (applies to admins too)
if has_personas:
return persona.id in provider_persona_ids if persona else False
# Neither groups nor personas set, and not public → admins can access
return is_admin
# No groups: either persona-whitelisted (already passed) or admin-only if locked
return has_personas or is_admin
def validate_persona_ids_exist(
@@ -213,11 +206,29 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
existing_llm_provider: LLMProviderModel | None = None
if llm_provider_upsert_request.id:
existing_llm_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if not existing_llm_provider:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} not found"
)
if not existing_llm_provider:
if existing_llm_provider.name != llm_provider_upsert_request.name:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
)
else:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider:
raise ValueError(
f"LLM provider with name '{llm_provider_upsert_request.name}'"
" already exists"
)
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
db_session.add(existing_llm_provider)
@@ -238,11 +249,7 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -306,15 +313,6 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -488,6 +486,22 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -604,22 +618,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
_update_default_model(
db_session,
provider_id,
provider.default_model_name, # type: ignore[arg-type]
model_name,
LLMModelFlowType.CHAT,
)
@@ -805,12 +810,6 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes

View File

@@ -103,7 +103,6 @@ from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.sensitive import SensitiveValue
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from onyx.context.search.enums import RecencyBiasSetting
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
# and update Mapped[User | None] relationships to Mapped[User] where needed.
@@ -3265,19 +3264,6 @@ class Persona(Base):
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
# Number of chunks to pass to the LLM for generation.
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
chunks_above: Mapped[int] = mapped_column(Integer)
chunks_below: Mapped[int] = mapped_column(Integer)
# Pass every chunk through LLM for evaluation, fairly expensive
# Can be turned off globally by admin, in which case, this setting is ignored
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
# Enables using LLM to extract time and source type filters
# Can also be admin disabled globally
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
# Allows the persona to specify a specific default LLM model
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -3304,11 +3290,8 @@ class Persona(Base):
# Treated specially (cannot be user edited etc.)
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# Default personas are personas created by admins and are automatically added
# to all users' assistants list.
is_default_persona: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
# Featured personas are highlighted in the UI
featured: Mapped[bool] = mapped_column(Boolean, default=False)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
@@ -4943,7 +4926,9 @@ class ScimUserMapping(Base):
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
external_id: Mapped[str | None] = mapped_column(
String, unique=True, index=True, nullable=True
)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
@@ -5000,3 +4985,25 @@ class CodeInterpreterServer(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
class CacheStore(Base):
"""Key-value cache table used by ``PostgresCacheBackend``.
Replaces Redis for simple KV caching, locks, and list operations
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
Intentionally separate from ``KVStore``:
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
- Holds ephemeral data (tokens, stop signals, lock state) not
persistent application config, so cleanup can be aggressive.
"""
__tablename__ = "cache_store"
key: Mapped[str] = mapped_column(String, primary_key=True)
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
expires_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)

View File

@@ -18,11 +18,8 @@ from sqlalchemy.orm import Session
from onyx.access.hierarchy_access import get_user_external_group_ids
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import NotificationType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
@@ -254,13 +251,15 @@ def create_update_persona(
# Permission to actually use these is checked later
try:
# Default persona validation
if create_persona_request.is_default_persona:
# Curators can edit default personas, but not make them
# Featured persona validation
if create_persona_request.featured:
# Curators can edit featured personas, but not make them
# TODO this will be reworked soon with RBAC permissions feature
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
pass
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
raise ValueError("Only admins can make a featured persona")
# Convert incoming string UUIDs to UUID objects for DB operations
converted_user_file_ids = None
@@ -281,7 +280,6 @@ def create_update_persona(
document_set_ids=create_persona_request.document_set_ids,
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
@@ -295,10 +293,7 @@ def create_update_persona(
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
featured=create_persona_request.featured,
user_file_ids=converted_user_file_ids,
commit=False,
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
@@ -874,10 +869,6 @@ def upsert_persona(
user: User | None,
name: str,
description: str,
num_chunks: float,
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
@@ -898,13 +889,11 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool | None = None,
featured: bool | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[UUID] | None = None,
hierarchy_node_ids: list[int] | None = None,
document_ids: list[str] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
replace_base_system_prompt: bool = False,
) -> Persona:
"""
@@ -1015,12 +1004,6 @@ def upsert_persona(
# `default` and `built-in` properties can only be set when creating a persona.
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
@@ -1034,10 +1017,8 @@ def upsert_persona(
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
existing_persona.featured = (
featured if featured is not None else existing_persona.featured
)
# Update embedded prompt fields if provided
if system_prompt is not None:
@@ -1090,12 +1071,6 @@ def upsert_persona(
is_public=is_public,
name=name,
description=description,
num_chunks=num_chunks,
chunks_above=chunks_above,
chunks_below=chunks_below,
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
system_prompt=system_prompt or "",
task_prompt=task_prompt or "",
@@ -1111,9 +1086,7 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=(
is_default_persona if is_default_persona is not None else False
),
featured=(featured if featured is not None else False),
user_files=user_files or [],
labels=labels or [],
hierarchy_nodes=hierarchy_nodes or [],
@@ -1158,9 +1131,9 @@ def delete_old_default_personas(
db_session.commit()
def update_persona_is_default(
def update_persona_featured(
persona_id: int,
is_default: bool,
featured: bool,
db_session: Session,
user: User,
) -> None:
@@ -1168,7 +1141,7 @@ def update_persona_is_default(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_default_persona = is_default
persona.featured = featured
db_session.commit()

View File

@@ -9,8 +9,9 @@ from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -105,8 +106,8 @@ def upload_files_to_user_files_with_indexing(
user: User,
temp_id_map: dict[str, str] | None,
db_session: Session,
background_tasks: BackgroundTasks | None = None,
) -> CategorizedFilesResult:
# Validate project ownership if a project_id is provided
if project_id is not None and user is not None:
if not check_project_ownership(project_id, user.id, db_session):
raise HTTPException(status_code=404, detail="Project not found")
@@ -127,16 +128,27 @@ def upload_files_to_user_files_with_indexing(
logger.warning(
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
)
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
)
if DISABLE_VECTOR_DB and background_tasks is not None:
from onyx.background.task_utils import drain_processing_loop
background_tasks.add_task(drain_processing_loop, tenant_id)
for user_file in user_files:
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} "
f"with task_id={task.id}"
)
return CategorizedFilesResult(
user_files=user_files,

View File

@@ -5,8 +5,6 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.models import ChannelConfig
@@ -45,8 +43,6 @@ def create_slack_channel_persona(
channel_name: str | None,
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
enable_auto_filters: bool = False,
) -> Persona:
"""NOTE: does not commit changes"""
@@ -73,17 +69,13 @@ def create_slack_channel_persona(
system_prompt="",
task_prompt="",
datetime_aware=True,
num_chunks=num_chunks,
llm_relevance_filter=True,
llm_filter_extraction=enable_auto_filters,
recency_bias=RecencyBiasSetting.AUTO,
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
is_public=True,
is_default_persona=False,
featured=False,
db_session=db_session,
commit=False,
)

View File

@@ -6,7 +6,6 @@ import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -563,12 +562,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
if not self._client.index_exists():
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
index_settings = DocumentSchema.get_index_settings_based_on_environment()
self._client.create_index(
mappings=expected_mappings,
settings=index_settings,

View File

@@ -12,6 +12,7 @@ from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -525,7 +526,7 @@ class DocumentSchema:
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch.
@@ -546,3 +547,41 @@ class DocumentSchema:
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch in multi-tenant cloud.
324 shards very roughly targets a storage load of ~30Gb per shard, which
according to AWS OpenSearch documentation is within a good target range.
As documented above we need 2 replicas for a total of 3 copies of the
data because the cluster is configured with 3-AZ awareness.
"""
return {
"index": {
"number_of_shards": 324,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
"""
Returns the index settings based on the environment.
"""
if USING_AWS_MANAGED_OPENSEARCH:
if MULTI_TENANT:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
)
else:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
)
else:
return DocumentSchema.get_index_settings()

View File

@@ -4,39 +4,33 @@ import base64
import json
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import Optional
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key prefix for OAuth state
OAUTH_STATE_PREFIX = "federated_oauth"
# Default TTL for OAuth state (5 minutes)
OAUTH_STATE_TTL = 300
OAUTH_STATE_TTL = 300 # 5 minutes
class OAuthSession:
"""Represents an OAuth session stored in Redis."""
"""Represents an OAuth session stored in the cache backend."""
def __init__(
self,
federated_connector_id: int,
user_id: str,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
):
self.federated_connector_id = federated_connector_id
self.user_id = user_id
self.redirect_uri = redirect_uri
self.additional_data = additional_data or {}
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for Redis storage."""
def to_dict(self) -> dict[str, Any]:
return {
"federated_connector_id": self.federated_connector_id,
"user_id": self.user_id,
@@ -45,8 +39,7 @@ class OAuthSession:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
"""Create from dictionary retrieved from Redis."""
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
return cls(
federated_connector_id=data["federated_connector_id"],
user_id=data["user_id"],
@@ -58,31 +51,27 @@ class OAuthSession:
def generate_oauth_state(
federated_connector_id: int,
user_id: str,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
ttl: int = OAUTH_STATE_TTL,
) -> str:
"""
Generate a secure state parameter and store session data in Redis.
Generate a secure state parameter and store session data in the cache backend.
Args:
federated_connector_id: ID of the federated connector
user_id: ID of the user initiating OAuth
redirect_uri: Optional redirect URI after OAuth completion
additional_data: Any additional data to store with the session
ttl: Time-to-live in seconds for the Redis key
ttl: Time-to-live in seconds for the cache key
Returns:
Base64-encoded state parameter
"""
# Generate a random UUID for the state
state_uuid = uuid.uuid4()
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
# Convert UUID to base64 for URL-safe state parameter
state_bytes = state_uuid.bytes
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
# Create session object
session = OAuthSession(
federated_connector_id=federated_connector_id,
user_id=user_id,
@@ -90,15 +79,9 @@ def generate_oauth_state(
additional_data=additional_data,
)
# Store in Redis with TTL
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
redis_client.set(
redis_key,
json.dumps(session.to_dict()),
ex=ttl,
)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
logger.info(
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
@@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> OAuthSession:
state_bytes = base64.urlsafe_b64decode(padded_state)
state_uuid = uuid.UUID(bytes=state_bytes)
# Look up in Redis
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
session_data = cast(bytes, redis_client.get(redis_key))
session_data = cache.get(cache_key)
if not session_data:
raise ValueError(f"OAuth state not found in Redis: {state}")
raise ValueError(f"OAuth state not found: {state}")
# Delete the key after retrieval (one-time use)
redis_client.delete(redis_key)
cache.delete(cache_key)
# Parse and return session
session_dict = json.loads(session_data)
return OAuthSession.from_dict(session_dict)

View File

@@ -1,13 +1,11 @@
import json
from typing import cast
from redis.client import Redis
from onyx.cache.interface import CacheBackend
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -20,22 +18,27 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, redis_client: Redis | None = None) -> None:
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client()
def __init__(self, cache: CacheBackend | None = None) -> None:
self._cache = cache
def _get_cache(self) -> CacheBackend:
if self._cache is None:
from onyx.cache.factory import get_cache_backend
self._cache = get_cache_backend()
return self._cache
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
try:
self.redis_client.set(
self._get_cache().set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
# Fallback gracefully to Postgres if Cache backend fails
logger.error(
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
)
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
@@ -53,16 +56,12 @@ class PgRedisKVStore(KeyValueStore):
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
if not refresh_cache:
try:
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
if not isinstance(redis_value, bytes):
raise ValueError(
f"Redis value for key '{key}' is not a bytes object"
)
return json.loads(redis_value.decode("utf-8"))
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
if cached is not None:
return json.loads(cached.decode("utf-8"))
except Exception as e:
logger.error(
f"Failed to get value from Redis for key '{key}': {str(e)}"
f"Failed to get value from cache for key '{key}': {str(e)}"
)
with get_session_with_current_tenant() as db_session:
@@ -79,21 +78,21 @@ class PgRedisKVStore(KeyValueStore):
value = None
try:
self.redis_client.set(
self._get_cache().set(
REDIS_KEY_PREFIX + key,
json.dumps(value),
ex=KV_REDIS_KEY_EXPIRATION,
)
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self.redis_client.delete(REDIS_KEY_PREFIX + key)
self._get_cache().delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
with get_session_with_current_tenant() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete()

View File

@@ -67,6 +67,18 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
usage as a dict with chat completion format instead of keeping it as
ResponseAPIUsage. Our patch creates a deep copy before modification.
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
to check for router calls, but when metadata is explicitly None (key exists with
value None), the default {} is not used
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
the real exception (e.g. AuthenticationError for wrong API key)
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
not iterable
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
against metadata being explicitly None. Triggered when Responses API bridge
passes **litellm_params containing metadata=None.
"""
import time
@@ -725,6 +737,44 @@ def _patch_logging_assembled_streaming_response() -> None:
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
def _patch_responses_metadata_none() -> None:
"""
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
None (the key exists, so the default is not used), causing:
TypeError: argument of type 'NoneType' is not iterable
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
This happens when the Responses API bridge calls litellm.responses() with
**litellm_params which may contain metadata=None.
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
which does not guard against metadata being explicitly None. Same pattern exists
on line 1407 for async path.
"""
import litellm as _litellm
from functools import wraps
original_responses = _litellm.responses
if getattr(original_responses, "_metadata_patched", False):
return
@wraps(original_responses)
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
if kwargs.get("metadata") is None:
kwargs["metadata"] = {}
return original_responses(*args, **kwargs)
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
_litellm.responses = _patched_responses
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -736,6 +786,7 @@ def apply_monkey_patches() -> None:
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
"""
_patch_ollama_chunk_parser()
_patch_openai_responses_parallel_tool_calls()
@@ -743,3 +794,4 @@ def apply_monkey_patches() -> None:
_patch_azure_responses_should_fake_stream()
_patch_responses_api_usage_format()
_patch_logging_assembled_streaming_response()
_patch_responses_metadata_none()

View File

@@ -13,44 +13,38 @@ from datetime import datetime
import httpx
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import sync_auto_mode_models
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key for caching the last updated timestamp (per-tenant)
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
def _get_cached_last_updated_at() -> datetime | None:
"""Get the cached last_updated_at timestamp from Redis."""
try:
redis_client = get_redis_client()
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
if value and isinstance(value, bytes):
# Value is bytes, decode to string then parse as ISO format
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
if value is not None:
return datetime.fromisoformat(value.decode("utf-8"))
except Exception as e:
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
logger.warning(f"Failed to get cached last_updated_at: {e}")
return None
def _set_cached_last_updated_at(updated_at: datetime) -> None:
"""Set the cached last_updated_at timestamp in Redis."""
try:
redis_client = get_redis_client()
# Store as ISO format string, with 24 hour expiration
redis_client.set(
_REDIS_KEY_LAST_UPDATED_AT,
get_cache_backend().set(
_CACHE_KEY_LAST_UPDATED_AT,
updated_at.isoformat(),
ex=60 * 60 * 24, # 24 hours
ex=_CACHE_TTL_SECONDS,
)
except Exception as e:
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
logger.warning(f"Failed to set cached last_updated_at: {e}")
def fetch_llm_recommendations_from_github(
@@ -148,9 +142,8 @@ def sync_llm_models_from_github(
def reset_cache() -> None:
"""Reset the cache timestamp in Redis. Useful for testing."""
"""Reset the cache timestamp. Useful for testing."""
try:
redis_client = get_redis_client()
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
except Exception as e:
logger.warning(f"Failed to reset cache in Redis: {e}")
logger.warning(f"Failed to reset cache: {e}")

View File

@@ -32,11 +32,14 @@ from onyx.auth.schemas import UserUpdate
from onyx.auth.users import auth_backend
from onyx.auth.users import create_onyx_oauth_router
from onyx.auth.users import fastapi_users
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
@@ -254,8 +257,53 @@ def include_auth_router_with_prefix(
)
def validate_cache_backend_settings() -> None:
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
The Postgres cache backend eliminates the Redis dependency, but only works
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
"""
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
raise RuntimeError(
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
"The Postgres cache backend is only supported in no-vector-DB "
"deployments where Celery is replaced by the in-process task runner."
)
def validate_no_vector_db_settings() -> None:
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
since these modes require infrastructure that is removed in no-vector-DB deployments.
"""
if not DISABLE_VECTOR_DB:
return
if MULTI_TENANT:
raise RuntimeError(
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
"Multi-tenant deployments require the vector database for "
"per-tenant document indexing and search. Run in single-tenant "
"mode when disabling the vector database."
)
from onyx.server.features.build.configs import ENABLE_CRAFT
if ENABLE_CRAFT:
raise RuntimeError(
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
"Onyx Craft requires background workers for sandbox lifecycle "
"management, which are removed in no-vector-DB deployments. "
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
@@ -324,8 +372,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()
if DISABLE_VECTOR_DB:
from onyx.background.periodic_poller import recover_stuck_user_files
from onyx.background.periodic_poller import start_periodic_poller
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
yield
if DISABLE_VECTOR_DB:
from onyx.background.periodic_poller import stop_periodic_poller
stop_periodic_poller()
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:

View File

@@ -3,10 +3,12 @@ import datetime
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from onyx.auth.schemas import UserRole
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.user_preferences import activate_user
from onyx.db.users import add_slack_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
@@ -243,6 +245,44 @@ def handle_message(
)
return False
elif (
not existing_user.is_active
and existing_user.role == UserRole.SLACK_USER
):
check_seat_fn = fetch_ee_implementation_or_noop(
"onyx.db.license",
"check_seat_availability",
None,
)
seat_result = check_seat_fn(db_session=db_session)
if seat_result is not None and not seat_result.available:
logger.info(
f"Blocked inactive Slack user {message_info.email}: "
f"{seat_result.error_message}"
)
respond_in_thread_or_channel(
client=client,
channel=channel,
thread_ts=message_info.msg_to_respond,
text=(
"We weren't able to respond because your organization "
"has reached its user seat limit. Your account is "
"currently deactivated and cannot be reactivated "
"until more seats are available. Please contact "
"your Onyx administrator."
),
)
return False
activate_user(existing_user, db_session)
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
"onyx.db.license",
"invalidate_license_cache",
None,
)
invalidate_license_cache_fn()
logger.info(f"Reactivated inactive Slack user {message_info.email}")
add_slack_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer

View File

@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
from onyx.db.persona import get_persona_snapshots_paginated
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_featured
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
is_public: bool
class IsDefaultRequest(BaseModel):
is_default_persona: bool
class IsFeaturedRequest(BaseModel):
featured: bool
@admin_router.patch("/{persona_id}/visible")
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
raise HTTPException(status_code=403, detail=str(e))
@admin_router.patch("/{persona_id}/default")
def patch_persona_default_status(
@admin_router.patch("/{persona_id}/featured")
def patch_persona_featured_status(
persona_id: int,
is_default_request: IsDefaultRequest,
is_featured_request: IsFeaturedRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_is_default(
update_persona_featured(
persona_id=persona_id,
is_default=is_default_request.is_default_persona,
featured=is_featured_request.featured,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona default status")
logger.exception("Failed to update persona featured status")
raise HTTPException(status_code=403, detail=str(e))

View File

@@ -5,7 +5,6 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import HierarchyNodeType
from onyx.db.models import Document
from onyx.db.models import HierarchyNode
@@ -108,11 +107,7 @@ class PersonaUpsertRequest(BaseModel):
name: str
description: str
document_set_ids: list[int]
num_chunks: float
is_public: bool
recency_bias: RecencyBiasSetting
llm_filter_extraction: bool
llm_relevance_filter: bool
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
@@ -128,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
)
search_start_date: datetime | None = None
label_ids: list[int] | None = None
is_default_persona: bool = False
featured: bool = False
display_priority: int | None = None
# Accept string UUIDs from frontend
user_file_ids: list[str] | None = None
@@ -155,9 +150,6 @@ class MinimalPersonaSnapshot(BaseModel):
tools: list[ToolSnapshot]
starter_messages: list[StarterMessage] | None
llm_relevance_filter: bool
llm_filter_extraction: bool
# only show document sets in the UI that the assistant has access to
document_sets: list[DocumentSetSummary]
# Counts for knowledge sources (used to determine if search tool should be enabled)
@@ -175,7 +167,7 @@ class MinimalPersonaSnapshot(BaseModel):
is_public: bool
is_visible: bool
display_priority: int | None
is_default_persona: bool
featured: bool
builtin_persona: bool
# Used for filtering
@@ -214,8 +206,6 @@ class MinimalPersonaSnapshot(BaseModel):
if should_expose_tool_to_fe(tool)
],
starter_messages=persona.starter_messages,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
document_sets=[
DocumentSetSummary.from_model(document_set)
for document_set in persona.document_sets
@@ -230,7 +220,7 @@ class MinimalPersonaSnapshot(BaseModel):
is_public=persona.is_public,
is_visible=persona.is_visible,
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
@@ -252,11 +242,9 @@ class PersonaSnapshot(BaseModel):
# Return string UUIDs to frontend for consistency
user_file_ids: list[str]
display_priority: int | None
is_default_persona: bool
featured: bool
builtin_persona: bool
starter_messages: list[StarterMessage] | None
llm_relevance_filter: bool
llm_filter_extraction: bool
tools: list[ToolSnapshot]
labels: list["PersonaLabelSnapshot"]
owner: MinimalUserSnapshot | None
@@ -265,7 +253,6 @@ class PersonaSnapshot(BaseModel):
document_sets: list[DocumentSetSummary]
llm_model_provider_override: str | None
llm_model_version_override: str | None
num_chunks: float | None
# Hierarchy nodes attached for scoped search
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
# Individual documents attached for scoped search
@@ -289,11 +276,9 @@ class PersonaSnapshot(BaseModel):
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
tools=[
ToolSnapshot.from_model(tool)
for tool in persona.tools
@@ -324,7 +309,6 @@ class PersonaSnapshot(BaseModel):
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks,
system_prompt=persona.system_prompt,
replace_base_system_prompt=persona.replace_base_system_prompt,
task_prompt=persona.task_prompt,
@@ -332,12 +316,10 @@ class PersonaSnapshot(BaseModel):
)
# Model with full context on perona's internal settings
# Model with full context on persona's internal settings
# This is used for flows which need to know all settings
class FullPersonaSnapshot(PersonaSnapshot):
search_start_date: datetime | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
@classmethod
def from_model(
@@ -360,7 +342,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
users=[
@@ -391,10 +373,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
DocumentSetSummary.from_model(document_set_model)
for document_set_model in persona.document_sets
],
num_chunks=persona.num_chunks,
search_start_date=persona.search_start_date,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
system_prompt=persona.system_prompt,

View File

@@ -2,6 +2,7 @@ import json
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import File
from fastapi import Form
@@ -12,13 +13,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -34,7 +29,6 @@ from onyx.db.models import UserProject
from onyx.db.persona import get_personas_by_ids
from onyx.db.projects import get_project_token_count
from onyx.db.projects import upload_files_to_user_files_with_indexing
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import ChatSessionRequest
from onyx.server.features.projects.models import TokenCountResponse
@@ -55,7 +49,27 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
def _trigger_user_file_project_sync(
user_file_id: UUID,
tenant_id: str,
background_tasks: BackgroundTasks | None = None,
) -> None:
if DISABLE_VECTOR_DB and background_tasks is not None:
from onyx.background.task_utils import drain_project_sync_loop
background_tasks.add_task(drain_project_sync_loop, tenant_id)
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
return
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.redis.redis_pool import get_redis_client
queue_depth = get_user_file_project_sync_queue_depth(client_app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
logger.warning(
@@ -111,6 +125,7 @@ def create_project(
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
def upload_user_files(
bg_tasks: BackgroundTasks,
files: list[UploadFile] = File(...),
project_id: int | None = Form(None),
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
@@ -137,12 +152,12 @@ def upload_user_files(
user=user,
temp_id_map=parsed_temp_id_map,
db_session=db_session,
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
)
return CategorizedFilesSnapshot.from_result(categorized_files_result)
except Exception as e:
# Log error with type, message, and stack for easier debugging
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
raise HTTPException(
status_code=500,
@@ -192,6 +207,7 @@ def get_files_in_project(
def unlink_user_file_from_project(
project_id: int,
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Response:
@@ -208,7 +224,6 @@ def unlink_user_file_from_project(
if project is None:
raise HTTPException(status_code=404, detail="Project not found")
user_id = user.id
user_file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
@@ -224,7 +239,7 @@ def unlink_user_file_from_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
return Response(status_code=204)
@@ -237,6 +252,7 @@ def unlink_user_file_from_project(
def link_user_file_to_project(
project_id: int,
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileSnapshot:
@@ -268,7 +284,7 @@ def link_user_file_to_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
return UserFileSnapshot.from_model(user_file)
@@ -335,7 +351,7 @@ def upsert_project_instructions(
class ProjectPayload(BaseModel):
project: UserProjectSnapshot
files: list[UserFileSnapshot] | None = None
persona_id_to_is_default: dict[int, bool] | None = None
persona_id_to_featured: dict[int, bool] | None = None
@router.get(
@@ -354,13 +370,11 @@ def get_project_details(
if session.persona_id is not None
]
personas = get_personas_by_ids(persona_ids, db_session)
persona_id_to_is_default = {
persona.id: persona.is_default_persona for persona in personas
}
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
return ProjectPayload(
project=project,
files=files,
persona_id_to_is_default=persona_id_to_is_default,
persona_id_to_featured=persona_id_to_featured,
)
@@ -426,6 +440,7 @@ def delete_project(
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
def delete_user_file(
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileDeleteResult:
@@ -458,15 +473,25 @@ def delete_user_file(
db_session.commit()
tenant_id = get_current_tenant_id()
task = client_app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
)
if DISABLE_VECTOR_DB:
from onyx.background.task_utils import drain_delete_loop
bg_tasks.add_task(drain_delete_loop, tenant_id)
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
task = client_app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered delete for user_file_id={user_file.id} "
f"with task_id={task.id}"
)
return UserFileDeleteResult(
has_associations=False, project_names=[], assistant_names=[]
)

View File

@@ -8,10 +8,10 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.cache.factory import get_shared_cache_backend
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
@@ -113,60 +113,46 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
def get_cached_etag() -> str | None:
"""Get the cached GitHub ETag from Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
try:
etag = redis_client.get(REDIS_KEY_ETAG)
etag = cache.get(REDIS_KEY_ETAG)
if etag:
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
return etag.decode("utf-8")
return None
except Exception as e:
logger.error(f"Failed to get cached etag from Redis: {e}")
logger.error(f"Failed to get cached etag: {e}")
return None
def get_last_fetch_time() -> datetime | None:
"""Get the last fetch timestamp from Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
try:
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
if not fetched_at_str:
raw = cache.get(REDIS_KEY_FETCHED_AT)
if not raw:
return None
decoded = (
fetched_at_str.decode("utf-8")
if isinstance(fetched_at_str, bytes)
else str(fetched_at_str)
)
last_fetch = datetime.fromisoformat(decoded)
# Defensively ensure timezone awareness
# fromisoformat() returns naive datetime if input lacks timezone
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
if last_fetch.tzinfo is None:
# Assume UTC for naive datetimes
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
else:
# Convert to UTC if timezone-aware
last_fetch = last_fetch.astimezone(timezone.utc)
return last_fetch
except Exception as e:
logger.error(f"Failed to get last fetch time from Redis: {e}")
logger.error(f"Failed to get last fetch time from cache: {e}")
return None
def save_fetch_metadata(etag: str | None) -> None:
"""Save ETag and fetch timestamp to Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
now = datetime.now(timezone.utc)
try:
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
if etag:
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
except Exception as e:
logger.error(f"Failed to save fetch metadata to Redis: {e}")
logger.error(f"Failed to save fetch metadata to cache: {e}")
def is_cache_stale() -> bool:
@@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
if not is_cache_stale():
return
# Acquire lock to prevent concurrent fetches
redis_client = get_shared_redis_client()
lock = redis_client.lock(
cache = get_shared_cache_backend()
lock = cache.lock(
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
timeout=90, # 90 second timeout for the lock
timeout=90,
)
# Non-blocking acquire - if we can't get the lock, another request is handling it

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.entities import get_entity_stats_by_grounded_source_name
from onyx.db.entity_type import get_configured_entity_types
@@ -134,11 +133,7 @@ def enable_or_disable_kg(
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
datetime_aware=False,
num_chunks=25,
llm_relevance_filter=False,
is_public=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
document_set_ids=[],
tool_ids=[search_tool.id, kg_tool.id],
llm_model_provider_override=None,
@@ -147,7 +142,7 @@ def enable_or_disable_kg(
users=[user.id],
groups=[],
label_ids=[],
is_default_persona=False,
featured=False,
display_priority=0,
user_file_ids=[],
)

View File

@@ -97,7 +97,6 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -233,12 +237,9 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -268,7 +269,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -303,7 +304,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -328,7 +329,15 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
@admin_router.put("/provider")
@@ -344,18 +353,44 @@ def put_llm_provider(
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_existing_llm_provider(
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
# Check name constraints
# TODO: Once port from name to id is complete, unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -393,22 +428,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -438,8 +457,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -460,37 +479,48 @@ def put_llm_provider(
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
provider_id: int,
force: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
if not force:
model = fetch_default_llm_model(db_session)
if model and model.llm_provider_id == provider_id:
raise HTTPException(
status_code=400,
detail="Cannot delete the default LLM provider",
)
remove_llm_provider(db_session, provider_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
@admin_router.post("/default")
def set_provider_as_default(
provider_id: int,
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
@admin_router.post("/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
default_model: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_model.provider_id,
vision_model=default_model.model_name,
db_session=db_session,
)
@@ -516,7 +546,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -545,7 +575,13 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return vision_provider_response
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
"""Endpoints for all"""
@@ -555,7 +591,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -577,9 +613,9 @@ def list_llm_provider_basics(
for provider in all_providers:
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes public providers WITHOUT persona restrictions
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes providers with persona restrictions (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
@@ -592,7 +628,15 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
def get_valid_model_names_for_persona(
@@ -604,7 +648,7 @@ def get_valid_model_names_for_persona(
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
available to the user when using this persona, respecting all RBAC restrictions.
Public providers are always included.
Public providers are included unless they have persona restrictions that exclude this persona.
"""
persona = fetch_persona_with_groups(db_session, persona_id)
if not persona:
@@ -618,7 +662,7 @@ def get_valid_model_names_for_persona(
valid_models = []
for llm_provider_model in all_providers:
# Public providers always included, restricted checked via RBAC
# Check access with persona context — respects all RBAC restrictions
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):
@@ -635,11 +679,11 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
- All public providers (is_public=True) - ALWAYS included
- Public providers (respecting persona restrictions if set)
- Restricted providers user can access via group/persona restrictions
This endpoint is used for background fetching of restricted providers
@@ -668,7 +712,7 @@ def list_llm_providers_for_persona(
llm_provider_list: list[LLMProviderDescriptor] = []
for llm_provider_model in all_providers:
# Use simplified access check - public providers always included
# Check access with persona context — respects persona restrictions
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):
@@ -682,7 +726,51 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
return llm_provider_list
# Get the default model and vision model for the persona
# TODO: Port persona's over to use ID
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text = DefaultModel.from_model_config(default_text_model)
default_vision = DefaultModel.from_model_config(default_vision_model)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider and can_user_access_llm_provider(
provider, user_group_ids, persona, is_admin=is_admin
):
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
@admin_router.get("/provider-contextual-cost")

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
# There is no logic that requires sending the providers default vision model name
# We only send for the one that is actually the default
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
"""Find the default conversation model name for a provider.
Returns the model name if found, otherwise returns empty string.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
return model_config.name
return ""
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
"""Find the default vision model name for a provider.
Returns the model name if found, otherwise returns None.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
return model_config.name
return None
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
id: int | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -99,24 +72,12 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -130,18 +91,17 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -157,8 +117,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -180,16 +138,6 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -202,10 +150,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -425,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -198,7 +198,6 @@ def patch_slack_channel_config(
channel_name=channel_config["channel_name"],
document_set_ids=slack_channel_config_creation_request.document_sets,
existing_persona_id=existing_persona_id,
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
).id
slack_channel_config_model = update_slack_channel_config(

View File

@@ -13,13 +13,13 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.cache.factory import get_cache_backend
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import convert_chat_history_basic
@@ -67,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
from onyx.server.api_key_usage import check_api_key_usage
from onyx.server.query_and_chat.models import ChatFeedbackRequest
@@ -152,10 +151,20 @@ def get_user_chat_sessions(
project_id: int | None = None,
only_non_project_chats: bool = True,
include_failed_chats: bool = False,
page_size: int = Query(default=50, ge=1, le=100),
before: str | None = Query(default=None),
) -> ChatSessionsResponse:
user_id = user.id
try:
before_dt = (
datetime.datetime.fromisoformat(before) if before is not None else None
)
except ValueError:
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
try:
# Fetch one extra to determine if there are more results
chat_sessions = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
@@ -163,11 +172,16 @@ def get_user_chat_sessions(
project_id=project_id,
only_non_project_chats=only_non_project_chats,
include_failed_chats=include_failed_chats,
limit=page_size + 1,
before=before_dt,
)
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
has_more = len(chat_sessions) > page_size
chat_sessions = chat_sessions[:page_size]
return ChatSessionsResponse(
sessions=[
ChatSessionDetails(
@@ -181,7 +195,8 @@ def get_user_chat_sessions(
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]
],
has_more=has_more,
)
@@ -314,7 +329,7 @@ def get_chat_session(
]
try:
is_processing = is_chat_session_processing(session_id, get_redis_client())
is_processing = is_chat_session_processing(session_id, get_cache_backend())
# Edit the last message to indicate loading (Overriding default message value)
if is_processing and chat_message_details:
last_msg = chat_message_details[-1]
@@ -911,11 +926,10 @@ async def search_chats(
def stop_chat_session(
chat_session_id: UUID,
user: User = Depends(current_user), # noqa: ARG001
redis_client: Redis = Depends(get_redis_client),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal in Redis.
Stop a chat session by setting a stop signal.
This endpoint is called by the frontend when the user clicks the stop button.
"""
set_fence(chat_session_id, redis_client, True)
set_fence(chat_session_id, get_cache_backend(), True)
return {"message": "Chat session stopped"}

View File

@@ -192,6 +192,7 @@ class ChatSessionDetails(BaseModel):
class ChatSessionsResponse(BaseModel):
sessions: list[ChatSessionDetails]
has_more: bool = False
class ChatMessageDetail(BaseModel):

View File

@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
class Notification(BaseModel):
@@ -82,6 +83,10 @@ class Settings(BaseModel):
# Default Assistant settings
disable_default_assistant: bool | None = False
# Seat usage - populated by license enforcement when seat limit is exceeded
seat_count: int | None = None
used_seats: int | None = None
# OpenSearch migration
opensearch_indexing_enabled: bool = False

View File

@@ -1,3 +1,4 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
@@ -6,11 +7,8 @@ from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -33,30 +31,22 @@ def load_settings() -> Settings:
logger.error(f"Error loading settings from KV store: {str(e)}")
settings = Settings()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
cache = get_cache_backend()
try:
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is not None:
assert isinstance(value, bytes)
anonymous_user_enabled = int(value.decode("utf-8")) == 1
else:
# Default to False
anonymous_user_enabled = False
# Optionally store the default back to Redis
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
)
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
except Exception as e:
# Log the error and reset to default
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
anonymous_user_enabled = False
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
# Override user knowledge setting if disabled via environment variable
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
@@ -66,11 +56,10 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
cache = get_cache_backend()
if settings.anonymous_user_enabled is not None:
redis_client.set(
cache.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
"1" if settings.anonymous_user_enabled else "0",
ex=SETTINGS_TTL,

View File

@@ -25,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -254,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
provider_name = "DevEnvPresetOpenAI"
existing = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
id=existing.id if existing else None,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -273,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -8,37 +8,3 @@ dependencies = [
[tool.uv.sources]
onyx = { workspace = true }
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
mypy_path = "backend"
explicit_package_bases = true
disallow_untyped_defs = true
warn_unused_ignores = true
enable_error_code = ["possibly-undefined"]
strict_equality = true
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
exclude = [
"(?:^|/)generated/",
"(?:^|/)\\.venv/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
]
[[tool.mypy.overrides]]
module = "alembic.versions.*"
disable_error_code = ["var-annotated"]
[[tool.mypy.overrides]]
module = "alembic_tenants.versions.*"
disable_error_code = ["var-annotated"]
[[tool.mypy.overrides]]
module = "generated.*"
follow_imports = "silent"
ignore_errors = true
[[tool.mypy.overrides]]
module = "transformers.*"
follow_imports = "skip"
ignore_errors = true

View File

@@ -528,7 +528,7 @@ lxml==5.3.0
# unstructured
# xmlsec
# zeep
lxml-html-clean==0.4.3
lxml-html-clean==0.4.4
# via lxml
magika==0.6.3
# via markitdown
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.7.3
pypdf==6.7.5
# via
# onyx
# unstructured-client

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()

View File

@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import StreamingError
from onyx.chat.process_message import handle_stream_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.models import RecencyBiasSetting
from onyx.db.models import User
from onyx.db.persona import upsert_persona
from onyx.server.query_and_chat.models import MessageResponseIDInfo
@@ -74,10 +73,6 @@ def test_stream_chat_message_objects_without_web_search(
user=None, # System persona
name=f"Test Persona {uuid.uuid4()}",
description="Test persona with no tools for web search test",
num_chunks=10.0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -0,0 +1,257 @@
"""External dependency unit tests for periodic task claiming.
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
timestamp for cross-instance dedup. The DB session is released before
the task runs, so long-running tasks don't hold connections.
"""
import time
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.background.periodic_poller import _PeriodicTaskDef
from onyx.background.periodic_poller import _try_claim_task
from onyx.background.periodic_poller import _try_run_periodic_task
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.models import KVStore
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
_TEST_LOCK_BASE = 90_000
@pytest.fixture(scope="module", autouse=True)
def _init_engine() -> None:
SqlEngine.init_engine(pool_size=10, max_overflow=5)
def _make_task(
*,
name: str | None = None,
interval: float = 3600,
lock_id: int | None = None,
run_fn: MagicMock | None = None,
) -> _PeriodicTaskDef:
return _PeriodicTaskDef(
name=name if name is not None else f"test-{uuid4().hex[:8]}",
interval_seconds=interval,
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
run_fn=run_fn if run_fn is not None else MagicMock(),
)
@pytest.fixture(autouse=True)
def _cleanup_kv(
tenant_context: None, # noqa: ARG001
) -> Generator[None, None, None]:
yield
with get_session_with_current_tenant() as db_session:
db_session.query(KVStore).filter(
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
).delete(synchronize_session=False)
db_session.commit()
# ------------------------------------------------------------------
# Happy-path: _try_claim_task
# ------------------------------------------------------------------
class TestClaimHappyPath:
def test_first_claim_succeeds(self) -> None:
assert _try_claim_task(_make_task()) is True
def test_first_claim_creates_kv_row(self) -> None:
task = _make_task()
_try_claim_task(task)
with get_session_with_current_tenant() as db_session:
row = (
db_session.query(KVStore)
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
.first()
)
assert row is not None
assert row.value is not None
def test_second_claim_within_interval_fails(self) -> None:
task = _make_task(interval=3600)
assert _try_claim_task(task) is True
assert _try_claim_task(task) is False
def test_claim_after_interval_succeeds(self) -> None:
task = _make_task(interval=1)
assert _try_claim_task(task) is True
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
with get_session_with_current_tenant() as db_session:
row = db_session.query(KVStore).filter_by(key=kv_key).first()
assert row is not None
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
db_session.commit()
assert _try_claim_task(task) is True
# ------------------------------------------------------------------
# Happy-path: _try_run_periodic_task
# ------------------------------------------------------------------
class TestRunHappyPath:
def test_runs_task_and_updates_last_run_at(self) -> None:
mock_fn = MagicMock()
task = _make_task(run_fn=mock_fn)
_try_run_periodic_task(task)
mock_fn.assert_called_once()
assert task.last_run_at > 0
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
mock_fn = MagicMock()
task = _make_task(run_fn=mock_fn, interval=3600)
task.last_run_at = time.monotonic()
_try_run_periodic_task(task)
mock_fn.assert_not_called()
def test_skips_when_db_claim_blocked(self) -> None:
name = f"test-{uuid4().hex[:8]}"
lock_id = _TEST_LOCK_BASE + 10
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
mock_fn = MagicMock()
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
_try_run_periodic_task(task)
mock_fn.assert_not_called()
def test_task_exception_does_not_propagate(self) -> None:
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
_try_run_periodic_task(task)
def test_claim_committed_before_task_runs(self) -> None:
"""The KV claim must be visible in the DB when run_fn executes."""
task_name = f"test-order-{uuid4().hex[:8]}"
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
claim_visible: list[bool] = []
def check_claim() -> None:
with get_session_with_current_tenant() as db_session:
row = db_session.query(KVStore).filter_by(key=kv_key).first()
claim_visible.append(row is not None and row.value is not None)
task = _PeriodicTaskDef(
name=task_name,
interval_seconds=3600,
lock_id=_TEST_LOCK_BASE + 11,
run_fn=check_claim,
)
_try_run_periodic_task(task)
assert claim_visible == [True]
# ------------------------------------------------------------------
# Concurrency: only one claimer should win
# ------------------------------------------------------------------
class TestClaimConcurrency:
def test_concurrent_claims_single_winner(self) -> None:
"""Many threads claim the same task — exactly one should succeed."""
num_threads = 20
task_name = f"test-race-{uuid4().hex[:8]}"
lock_id = _TEST_LOCK_BASE + 20
def claim() -> bool:
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
return _try_claim_task(
_PeriodicTaskDef(
name=task_name,
interval_seconds=3600,
lock_id=lock_id,
run_fn=lambda: None,
)
)
results: list[bool] = []
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(claim) for _ in range(num_threads)]
for future in as_completed(futures):
results.append(future.result())
winners = sum(1 for r in results if r)
assert winners == 1, f"Expected 1 winner, got {winners}"
def test_concurrent_run_single_execution(self) -> None:
"""Many threads run the same task — run_fn fires exactly once."""
num_threads = 20
task_name = f"test-run-race-{uuid4().hex[:8]}"
lock_id = _TEST_LOCK_BASE + 21
counter = MagicMock()
def run() -> None:
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
_try_run_periodic_task(
_PeriodicTaskDef(
name=task_name,
interval_seconds=3600,
lock_id=lock_id,
run_fn=counter,
)
)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(run) for _ in range(num_threads)]
for future in as_completed(futures):
future.result()
assert (
counter.call_count == 1
), f"Expected run_fn called once, got {counter.call_count}"
def test_no_errors_under_contention(self) -> None:
"""All threads complete without exceptions under high contention."""
num_threads = 30
task_name = f"test-err-{uuid4().hex[:8]}"
lock_id = _TEST_LOCK_BASE + 22
errors: list[Exception] = []
def claim() -> bool:
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
return _try_claim_task(
_PeriodicTaskDef(
name=task_name,
interval_seconds=3600,
lock_id=lock_id,
run_fn=lambda: None,
)
)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(claim) for _ in range(num_threads)]
for future in as_completed(futures):
try:
future.result()
except Exception as e:
errors.append(e)
assert errors == [], f"Got {len(errors)} errors: {errors}"

View File

@@ -0,0 +1,352 @@
"""External dependency unit tests for startup recovery (Step 10g).
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
The per-file ``*_impl`` functions are mocked so no real file store or
connector is needed — we only verify that recovery finds and dispatches
the correct files.
"""
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from onyx.background.periodic_poller import recover_stuck_user_files
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
def _create_user_file(
db_session: Session,
user_id: object,
*,
status: UserFileStatus = UserFileStatus.PROCESSING,
needs_project_sync: bool = False,
needs_persona_sync: bool = False,
) -> UserFile:
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=status,
needs_project_sync=needs_project_sync,
needs_persona_sync=needs_persona_sync,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
def _fake_delete_impl(
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
) -> None:
"""Mock side-effect: delete the row so the drain loop terminates."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as session:
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
session.commit()
def _fake_sync_impl(
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
) -> None:
"""Mock side-effect: clear sync flags so the drain loop terminates."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as session:
session.execute(
sa.update(UserFile)
.where(UserFile.id == UUID(user_file_id))
.values(needs_project_sync=False, needs_persona_sync=False)
)
session.commit()
@pytest.fixture()
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
"""Track created UserFile rows and delete them after each test."""
created: list[UserFile] = []
yield created
for uf in created:
existing = db_session.get(UserFile, uf.id)
if existing:
db_session.delete(existing)
db_session.commit()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestRecoverProcessingFiles:
"""Files in PROCESSING status are re-processed via the processing drain loop."""
def test_processing_files_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_proc")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
def test_completed_files_not_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_comp")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) not in called_ids
), f"COMPLETED file {uf.id} should not have been recovered"
class TestRecoverDeletingFiles:
"""Files in DELETING status are recovered via the delete drain loop."""
def test_deleting_files_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_del")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
# Row is deleted by _fake_delete_impl, so no cleanup needed.
mock_impl = MagicMock(side_effect=_fake_delete_impl)
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
class TestRecoverSyncFiles:
"""Files needing project/persona sync are recovered via the sync drain loop."""
def test_needs_project_sync_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_sync")
uf = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_project_sync=True,
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock(side_effect=_fake_sync_impl)
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
def test_needs_persona_sync_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_psync")
uf = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_persona_sync=True,
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock(side_effect=_fake_sync_impl)
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
class TestRecoveryMultipleFiles:
"""Recovery processes all stuck files in one pass, not just the first."""
def test_multiple_processing_files(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_multi")
files = []
for _ in range(3):
uf = _create_user_file(
db_session, user.id, status=UserFileStatus.PROCESSING
)
_cleanup_user_files.append(uf)
files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
expected_ids = {str(uf.id) for uf in files}
assert expected_ids.issubset(called_ids), (
f"Expected all {len(files)} files to be recovered. "
f"Missing: {expected_ids - called_ids}"
)
class TestTransientFailures:
"""Drain loops skip failed files, process the rest, and terminate."""
def test_processing_failure_skips_and_continues(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "fail_proc")
uf_fail = _create_user_file(
db_session, user.id, status=UserFileStatus.PROCESSING
)
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
_cleanup_user_files.extend([uf_fail, uf_ok])
fail_id = str(uf_fail.id)
def side_effect(
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
) -> None:
if user_file_id == fail_id:
raise RuntimeError("transient failure")
mock_impl = MagicMock(side_effect=side_effect)
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert fail_id in called_ids, "Failed file should have been attempted"
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
assert called_ids.count(str(uf_ok.id)) == 1
def test_delete_failure_skips_and_continues(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "fail_del")
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
_cleanup_user_files.append(uf_fail)
fail_id = str(uf_fail.id)
def side_effect(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
if user_file_id == fail_id:
raise RuntimeError("transient failure")
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
mock_impl = MagicMock(side_effect=side_effect)
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert fail_id in called_ids, "Failed file should have been attempted"
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
assert called_ids.count(str(uf_ok.id)) == 1
def test_sync_failure_skips_and_continues(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "fail_sync")
uf_fail = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_project_sync=True,
)
uf_ok = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_persona_sync=True,
)
_cleanup_user_files.extend([uf_fail, uf_ok])
fail_id = str(uf_fail.id)
def side_effect(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
if user_file_id == fail_id:
raise RuntimeError("transient failure")
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
mock_impl = MagicMock(side_effect=side_effect)
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert fail_id in called_ids, "Failed file should have been attempted"
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
assert called_ids.count(str(uf_ok.id)) == 1

View File

@@ -0,0 +1,57 @@
"""Fixtures for cache backend tests.
Requires a running PostgreSQL instance (and Redis for parity tests).
Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
"""
from collections.abc import Generator
import pytest
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="session", autouse=True)
def _init_db() -> Generator[None, None, None]:
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
SqlEngine.init_engine(pool_size=5, max_overflow=2)
yield
@pytest.fixture(autouse=True)
def _tenant_context() -> Generator[None, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture
def pg_cache() -> PostgresCacheBackend:
return PostgresCacheBackend(TEST_TENANT_ID)
@pytest.fixture
def redis_cache() -> RedisCacheBackend:
from onyx.redis.redis_pool import redis_pool
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
def cache(
request: pytest.FixtureRequest,
pg_cache: PostgresCacheBackend,
redis_cache: RedisCacheBackend,
) -> CacheBackend:
if request.param == "postgres":
return pg_cache
return redis_cache

View File

@@ -0,0 +1,100 @@
"""Parameterized tests that run the same CacheBackend operations against
both Redis and PostgreSQL, asserting identical return values.
Each test runs twice (once per backend) via the ``cache`` fixture defined
in conftest.py.
"""
import time
from uuid import uuid4
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
def _key() -> str:
return f"parity_{uuid4().hex[:12]}"
class TestKVParity:
def test_get_missing(self, cache: CacheBackend) -> None:
assert cache.get(_key()) is None
def test_get_set(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"value")
assert cache.get(k) == b"value"
def test_overwrite(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"a")
cache.set(k, b"b")
assert cache.get(k) == b"b"
def test_set_string(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, "hello")
assert cache.get(k) == b"hello"
def test_set_int(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, 42)
assert cache.get(k) == b"42"
def test_delete(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x")
cache.delete(k)
assert cache.get(k) is None
def test_exists(self, cache: CacheBackend) -> None:
k = _key()
assert not cache.exists(k)
cache.set(k, b"x")
assert cache.exists(k)
class TestTTLParity:
def test_ttl_missing(self, cache: CacheBackend) -> None:
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x")
assert cache.ttl(k) == TTL_NO_EXPIRY
def test_ttl_remaining(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x", ex=10)
remaining = cache.ttl(k)
assert 8 <= remaining <= 10
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x", ex=1)
assert cache.get(k) == b"x"
time.sleep(1.5)
assert cache.get(k) is None
class TestLockParity:
def test_acquire_release(self, cache: CacheBackend) -> None:
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
assert lock.acquire(blocking=False)
assert lock.owned()
lock.release()
assert not lock.owned()
class TestListParity:
def test_rpush_blpop(self, cache: CacheBackend) -> None:
k = f"parity_list_{uuid4().hex[:8]}"
cache.rpush(k, b"item")
result = cache.blpop([k], timeout=1)
assert result is not None
assert result[1] == b"item"
def test_blpop_timeout(self, cache: CacheBackend) -> None:
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
assert result is None

View File

@@ -0,0 +1,129 @@
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
Verifies that the KV store correctly uses the CacheBackend for caching
in front of PostgreSQL: cache hits, cache misses falling through to PG,
cache population after PG reads, cache invalidation on delete, and
graceful degradation when the cache backend raises.
Requires running PostgreSQL.
"""
import json
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from sqlalchemy import delete
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import CacheStore
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.key_value_store.store import PgRedisKVStore
from onyx.key_value_store.store import REDIS_KEY_PREFIX
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(autouse=True)
def _clean_kv() -> Generator[None, None, None]:
yield
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
session.execute(delete(KVStore))
session.execute(delete(CacheStore))
session.commit()
@pytest.fixture
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
return PgRedisKVStore(cache=pg_cache)
class TestStoreAndLoad:
def test_store_populates_cache_and_pg(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k1", {"hello": "world"})
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
assert cached is not None
assert json.loads(cached) == {"hello": "world"}
loaded = kv_store.load("k1")
assert loaded == {"hello": "world"}
def test_load_returns_cached_value_without_pg_hit(
self, pg_cache: PostgresCacheBackend
) -> None:
"""If the cache already has the value, PG should not be queried."""
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
kv = PgRedisKVStore(cache=pg_cache)
assert kv.load("cached_only") == {"from": "cache"}
def test_load_falls_through_to_pg_on_cache_miss(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k2", [1, 2, 3])
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
loaded = kv_store.load("k2")
assert loaded == [1, 2, 3]
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
assert repopulated is not None
assert json.loads(repopulated) == [1, 2, 3]
def test_load_with_refresh_cache_skips_cache(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k3", "original")
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
loaded = kv_store.load("k3", refresh_cache=True)
assert loaded == "original"
class TestDelete:
def test_delete_removes_from_cache_and_pg(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("del_me", "bye")
kv_store.delete("del_me")
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
with pytest.raises(KvKeyNotFoundError):
kv_store.load("del_me")
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
with pytest.raises(KvKeyNotFoundError):
kv_store.delete("nonexistent")
class TestCacheFailureGracefulDegradation:
def test_store_succeeds_when_cache_set_raises(self) -> None:
failing_cache = MagicMock(spec=CacheBackend)
failing_cache.set.side_effect = ConnectionError("cache down")
kv = PgRedisKVStore(cache=failing_cache)
kv.store("resilient", {"data": True})
working_cache = MagicMock(spec=CacheBackend)
working_cache.get.return_value = None
kv_reader = PgRedisKVStore(cache=working_cache)
loaded = kv_reader.load("resilient")
assert loaded == {"data": True}
def test_load_falls_through_when_cache_get_raises(self) -> None:
failing_cache = MagicMock(spec=CacheBackend)
failing_cache.get.side_effect = ConnectionError("cache down")
failing_cache.set.side_effect = ConnectionError("cache down")
kv = PgRedisKVStore(cache=failing_cache)
kv.store("survive", 42)
loaded = kv.load("survive")
assert loaded == 42

View File

@@ -0,0 +1,229 @@
"""Tests for PostgresCacheBackend against real PostgreSQL.
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
locks (acquire / release / contention), list operations (rpush / blpop),
and the periodic cleanup function.
"""
import time
from uuid import uuid4
from sqlalchemy import select
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.models import CacheStore
def _key() -> str:
return f"test_{uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Basic KV
# ------------------------------------------------------------------
class TestKV:
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"hello")
assert pg_cache.get(k) == b"hello"
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
assert pg_cache.get(_key()) is None
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"first")
pg_cache.set(k, b"second")
assert pg_cache.get(k) == b"second"
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, "string_val")
assert pg_cache.get(k) == b"string_val"
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, 42)
assert pg_cache.get(k) == b"42"
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"to_delete")
pg_cache.delete(k)
assert pg_cache.get(k) is None
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
pg_cache.delete(_key())
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
assert not pg_cache.exists(k)
pg_cache.set(k, b"x")
assert pg_cache.exists(k)
# ------------------------------------------------------------------
# TTL
# ------------------------------------------------------------------
class TestTTL:
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"ephemeral", ex=1)
assert pg_cache.get(k) == b"ephemeral"
time.sleep(1.5)
assert pg_cache.get(k) is None
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"forever")
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=10)
remaining = pg_cache.ttl(k)
assert 8 <= remaining <= 10
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=1)
time.sleep(1.5)
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x")
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
pg_cache.expire(k, 10)
assert 8 <= pg_cache.ttl(k) <= 10
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=1)
assert pg_cache.exists(k)
time.sleep(1.5)
assert not pg_cache.exists(k)
# ------------------------------------------------------------------
# Locks
# ------------------------------------------------------------------
class TestLock:
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
assert lock.acquire(blocking=False)
assert lock.owned()
lock.release()
assert not lock.owned()
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
name = f"contention_{uuid4().hex[:8]}"
lock1 = pg_cache.lock(name)
lock2 = pg_cache.lock(name)
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
assert lock2.acquire(blocking=False)
lock2.release()
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
assert lock.owned()
assert not lock.owned()
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
name = f"timeout_{uuid4().hex[:8]}"
holder = pg_cache.lock(name)
holder.acquire(blocking=False)
waiter = pg_cache.lock(name, timeout=0.3)
start = time.monotonic()
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
elapsed = time.monotonic() - start
assert elapsed >= 0.25
holder.release()
# ------------------------------------------------------------------
# List (rpush / blpop)
# ------------------------------------------------------------------
class TestList:
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
k = f"list_{uuid4().hex[:8]}"
pg_cache.rpush(k, b"item1")
result = pg_cache.blpop([k], timeout=1)
assert result is not None
assert result == (k.encode(), b"item1")
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
assert result is None
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
k = f"fifo_{uuid4().hex[:8]}"
pg_cache.rpush(k, b"first")
time.sleep(0.01)
pg_cache.rpush(k, b"second")
r1 = pg_cache.blpop([k], timeout=1)
r2 = pg_cache.blpop([k], timeout=1)
assert r1 is not None and r1[1] == b"first"
assert r2 is not None and r2[1] == b"second"
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
k1 = f"mk1_{uuid4().hex[:8]}"
k2 = f"mk2_{uuid4().hex[:8]}"
pg_cache.rpush(k2, b"from_k2")
result = pg_cache.blpop([k1, k2], timeout=1)
assert result is not None
assert result == (k2.encode(), b"from_k2")
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
class TestCleanup:
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
from onyx.db.engine.sql_engine import get_session_with_current_tenant
k = _key()
pg_cache.set(k, b"stale", ex=1)
time.sleep(1.5)
cleanup_expired_cache_entries()
stmt = select(CacheStore.key).where(CacheStore.key == k)
with get_session_with_current_tenant() as session:
row = session.execute(stmt).first()
assert row is None, "expired row should be physically deleted"
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"fresh", ex=300)
cleanup_expired_cache_entries()
assert pg_cache.get(k) == b"fresh"
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"permanent")
cleanup_expired_cache_entries()
assert pg_cache.get(k) == b"permanent"

View File

@@ -36,7 +36,6 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
from onyx.background.celery.tasks.user_file_processing.tasks import (
user_file_project_sync_lock_key,
)
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
@@ -86,12 +85,6 @@ def _create_test_persona(
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a test assistant",
task_prompt="Answer the question",
tools=[],
@@ -410,10 +403,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -442,10 +431,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -461,16 +446,11 @@ class TestUpsertPersonaMarksSyncFlag:
uf_old.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
# Now update the persona to swap files
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -501,10 +481,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -519,15 +495,10 @@ class TestUpsertPersonaMarksSyncFlag:
uf.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -18,7 +18,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
@@ -58,12 +57,6 @@ def _create_persona(db_session: Session, user: User) -> Persona:
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="test",
task_prompt="test",
tools=[],

View File

@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -44,15 +45,14 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> None:
) -> LLMProviderView:
"""Helper to create a test LLM provider in the database."""
upsert_llm_provider(
return upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
upsert_llm_provider(
provider = upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name,
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
# Test should fail
with patch(

View File

@@ -49,7 +49,6 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when provider has no custom base URL.
"""
try:
_create_test_provider(db_session, provider_name, api_base=None)
view = _create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
name=name,
id=provider.id,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=False,
),
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
assert False, "Provider not found in the database"
assert provider.custom_config == custom_config, (
assert db_provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {provider.custom_config}"
f"but got {db_provider.custom_config}"
)
finally:
db_session.rollback()
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
put_llm_provider(
view = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider = LlmProviderNames.VERTEX_AI.value
provider_name = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
provider=provider_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
id=provider.id,
provider=provider_name,
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -15,9 +15,11 @@ import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, expected_default_model, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
)
# Verify initial state: all models are visible
provider = fetch_existing_llm_provider(
initial_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
assert initial_provider is not None
assert initial_provider.is_auto_mode is False
for mc in provider.model_configurations:
for mc in initial_provider.model_configurations:
assert (
mc.is_visible is True
), f"Initial model '{mc.name}' should be visible"
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=initial_provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
provider_view = fetch_llm_provider_view(
provider_name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is True
assert provider_view is not None
assert provider_view.is_auto_mode is True
# Build a map of model name -> visibility
model_visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
mc.name: mc.is_visible for mc in provider_view.model_configurations
}
# Models in auto mode config should be visible
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
),
is_creation=True,
_=_create_mock_admin(),
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_default_model, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
class TestAutoModeTransitionsAndResync:
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Disabling auto mode should preserve the current model list and
prevent future syncs from altering visibility.
Steps:
1. Create provider in auto mode — models synced from config.
2. Update provider to manual mode (is_auto_mode=False).
3. Verify all models remain with unchanged visibility.
4. Call sync_auto_mode_models with a *different* config.
5. Verify fetch_auto_mode_providers excludes this provider, so the
periodic task would never call sync on it.
"""
initial_config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create in auto mode
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=initial_config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility_before = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
# Step 2: Switch to manual mode
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
is_auto_mode=False,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
),
],
),
is_creation=False,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 3: Models unchanged
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
visibility_after = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_after == visibility_before
# Step 4-5: Provider excluded from auto mode queries
auto_providers = fetch_auto_mode_providers(db_session)
auto_provider_ids = {p.id for p in auto_providers}
assert provider.id not in auto_provider_ids
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_resync_adds_new_and_hides_removed_models(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the GitHub config changes between syncs, a subsequent sync
should add newly listed models and hide models that were removed.
Steps:
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
gpt-4-turbo added).
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
and visible.
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4-turbo"],
)
try:
# Step 1: Create with config v1
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Re-sync with config v2
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 3: Verify
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
# gpt-4o: still in config -> visible
assert visibility["gpt-4o"] is True
# gpt-4o-mini: removed from config -> hidden (not deleted)
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
assert visibility["gpt-4o-mini"] is False
# gpt-4-turbo: newly added -> visible
assert visibility["gpt-4-turbo"] is True
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_sync_is_idempotent(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Running sync twice with the same config should produce zero
changes on the second call."""
config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
# First explicit sync (may report changes if creation already synced)
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
# Snapshot state after first sync
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
snapshot = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
# Second sync — should be a no-op
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
assert (
changes == 0
), f"Expected 0 changes on idempotent re-sync, got {changes}"
# State should be identical
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
current = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
assert current == snapshot
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_default_model_hidden_when_removed_from_config(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the current default model is removed from the config, sync
should hide it. The default model flow row should still exist (it
points at the ModelConfiguration), but the model is no longer visible.
Steps:
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
2. Set gpt-4o as the global default.
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
fetch_default_llm_model still returns a result (the flow row persists).
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o-mini",
additional_models=[],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Set gpt-4o as global default
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
update_default_provider(provider.id, "gpt-4o", db_session)
default_before = fetch_default_llm_model(db_session)
assert default_before is not None
assert default_before.name == "gpt-4o"
# Step 3: Re-sync with config v2 (gpt-4o removed)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 4: Verify visibility
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
# but the model is hidden. fetch_default_llm_model filters on
# is_visible=True, so it should NOT return gpt-4o.
db_session.expire_all()
default_after = fetch_default_llm_model(db_session)
assert (
default_after is None or default_after.name != "gpt-4o"
), "Hidden model should not be returned as the default"
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)

View File

@@ -64,7 +64,6 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(public_provider_id, db_session)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
try:
# Create chat session

View File

@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
name=provider_name,
provider="openai",
api_key="test-api-key",
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
from onyx.configs.constants import FederatedConnectorSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -55,11 +54,6 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
persona = Persona(
name=f"test_slack_persona_{unique_id}",
description="Test persona for Slack federated search",
chunks_above=0,
chunks_below=0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question based on the provided context.",
)
@@ -434,7 +428,6 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -448,7 +441,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, db_session)
update_default_provider(provider_view.id, "gpt-4o", db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""
@@ -819,11 +812,6 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
persona = Persona(
name=f"test_eager_load_persona_{unique_id}",
description="Test persona for eager loading test",
chunks_above=0,
chunks_below=0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question.",
)

View File

@@ -21,7 +21,6 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -47,12 +46,6 @@ def _create_test_persona_with_mcp_tool(
persona = Persona(
name=f"Test MCP Persona {uuid4().hex[:8]}",
description="Test persona with MCP tools",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a helpful assistant",
task_prompt="Answer the user's question",
tools=tools,

View File

@@ -17,7 +17,6 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -57,12 +56,6 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a helpful assistant",
task_prompt="Answer the user's question",
tools=tools,

View File

@@ -933,6 +933,7 @@ from unittest.mock import patch
import pytest
from fastapi import UploadFile
from fastapi.background import BackgroundTasks
from sqlalchemy.orm import Session
from starlette.datastructures import Headers
@@ -1139,6 +1140,7 @@ def test_code_interpreter_receives_chat_files(
# Upload a test CSV
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
result = upload_user_files(
bg_tasks=BackgroundTasks(),
files=[
UploadFile(
file=io.BytesIO(csv_content),

View File

@@ -4,10 +4,12 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -32,7 +34,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -65,7 +66,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
default_model_name=default_model_name or "gpt-4o-mini",
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -75,9 +76,19 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -104,7 +115,7 @@ class LLMProviderManager:
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
return [LLMProviderView(**p) for p in response.json()["providers"]]
@staticmethod
def verify(
@@ -113,7 +124,11 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -126,11 +141,30 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and (
default_model is None or default_model.model_name in model_names
)
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
default_text = response.json().get("default_text")
if default_text is None:
return None
return DefaultModel(**default_text)

View File

@@ -3,7 +3,6 @@ from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -20,11 +19,7 @@ class PersonaManager:
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
num_chunks: float = 5,
llm_relevance_filter: bool = True,
is_public: bool = True,
llm_filter_extraction: bool = True,
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -35,6 +30,7 @@ class PersonaManager:
label_ids: list[int] | None = None,
user_file_ids: list[str] | None = None,
display_priority: int | None = None,
featured: bool = False,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
description = description or f"Description for {name}"
@@ -47,11 +43,7 @@ class PersonaManager:
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
@@ -61,6 +53,7 @@ class PersonaManager:
label_ids=label_ids or [],
user_file_ids=user_file_ids or [],
display_priority=display_priority,
featured=featured,
)
response = requests.post(
@@ -75,11 +68,7 @@ class PersonaManager:
id=persona_data["id"],
name=name,
description=description,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
@@ -90,6 +79,7 @@ class PersonaManager:
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
featured=featured,
)
@staticmethod
@@ -100,11 +90,7 @@ class PersonaManager:
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
num_chunks: float | None = None,
llm_relevance_filter: bool | None = None,
is_public: bool | None = None,
llm_filter_extraction: bool | None = None,
recency_bias: RecencyBiasSetting | None = None,
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -113,6 +99,7 @@ class PersonaManager:
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
featured: bool | None = None,
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
@@ -123,13 +110,7 @@ class PersonaManager:
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
num_chunks=num_chunks or persona.num_chunks,
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
is_public=persona.is_public if is_public is None else is_public,
llm_filter_extraction=(
llm_filter_extraction or persona.llm_filter_extraction
),
recency_bias=recency_bias or persona.recency_bias,
document_set_ids=document_set_ids or persona.document_set_ids,
tool_ids=tool_ids or persona.tool_ids,
llm_model_provider_override=(
@@ -141,6 +122,7 @@ class PersonaManager:
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
label_ids=label_ids or persona.label_ids,
featured=featured if featured is not None else persona.featured,
)
response = requests.patch(
@@ -155,16 +137,12 @@ class PersonaManager:
id=updated_persona_data["id"],
name=updated_persona_data["name"],
description=updated_persona_data["description"],
num_chunks=updated_persona_data["num_chunks"],
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
is_public=updated_persona_data["is_public"],
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
recency_bias=recency_bias or persona.recency_bias,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
document_set_ids=updated_persona_data["document_sets"],
tool_ids=updated_persona_data["tools"],
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
llm_model_provider_override=updated_persona_data[
"llm_model_provider_override"
],
@@ -173,7 +151,8 @@ class PersonaManager:
],
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
label_ids=updated_persona_data["labels"],
label_ids=[label["id"] for label in updated_persona_data["labels"]],
featured=updated_persona_data["featured"],
)
@staticmethod
@@ -222,32 +201,13 @@ class PersonaManager:
fetched_persona.description,
)
)
if fetched_persona.num_chunks != persona.num_chunks:
mismatches.append(
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
)
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
mismatches.append(
(
"llm_relevance_filter",
persona.llm_relevance_filter,
fetched_persona.llm_relevance_filter,
)
)
if fetched_persona.is_public != persona.is_public:
mismatches.append(
("is_public", persona.is_public, fetched_persona.is_public)
)
if (
fetched_persona.llm_filter_extraction
!= persona.llm_filter_extraction
):
if fetched_persona.featured != persona.featured:
mismatches.append(
(
"llm_filter_extraction",
persona.llm_filter_extraction,
fetched_persona.llm_filter_extraction,
)
("featured", persona.featured, fetched_persona.featured)
)
if (
fetched_persona.llm_model_provider_override

View File

@@ -0,0 +1,66 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
class ScimClient:
"""HTTP client for making authenticated SCIM v2 requests."""
@staticmethod
def _headers(raw_token: str) -> dict[str, str]:
return {
**GENERAL_HEADERS,
"Authorization": f"Bearer {raw_token}",
}
@staticmethod
def get(path: str, raw_token: str) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=ScimClient._headers(raw_token),
timeout=60,
)
@staticmethod
def post(path: str, raw_token: str, json: dict) -> requests.Response:
return requests.post(
f"{API_SERVER_URL}/scim/v2{path}",
json=json,
headers=ScimClient._headers(raw_token),
timeout=60,
)
@staticmethod
def put(path: str, raw_token: str, json: dict) -> requests.Response:
return requests.put(
f"{API_SERVER_URL}/scim/v2{path}",
json=json,
headers=ScimClient._headers(raw_token),
timeout=60,
)
@staticmethod
def patch(path: str, raw_token: str, json: dict) -> requests.Response:
return requests.patch(
f"{API_SERVER_URL}/scim/v2{path}",
json=json,
headers=ScimClient._headers(raw_token),
timeout=60,
)
@staticmethod
def delete(path: str, raw_token: str) -> requests.Response:
return requests.delete(
f"{API_SERVER_URL}/scim/v2{path}",
headers=ScimClient._headers(raw_token),
timeout=60,
)
@staticmethod
def get_no_auth(path: str) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=GENERAL_HEADERS,
timeout=60,
)

View File

@@ -1,7 +1,6 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestScimToken
from tests.integration.common_utils.test_models import DATestUser
@@ -51,29 +50,3 @@ class ScimTokenManager:
created_at=data["created_at"],
last_used_at=data.get("last_used_at"),
)
@staticmethod
def get_scim_headers(raw_token: str) -> dict[str, str]:
return {
**GENERAL_HEADERS,
"Authorization": f"Bearer {raw_token}",
}
@staticmethod
def scim_get(
path: str,
raw_token: str,
) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=ScimTokenManager.get_scim_headers(raw_token),
timeout=60,
)
@staticmethod
def scim_get_no_auth(path: str) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=GENERAL_HEADERS,
timeout=60,
)

View File

@@ -10,7 +10,6 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.db.enums import AccessType
@@ -128,7 +127,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str
default_model_name: str | None = None
is_public: bool
is_auto_mode: bool = False
groups: list[int]
@@ -162,11 +161,7 @@ class DATestPersona(BaseModel):
id: int
name: str
description: str
num_chunks: float
llm_relevance_filter: bool
is_public: bool
llm_filter_extraction: bool
recency_bias: RecencyBiasSetting
document_set_ids: list[int]
tool_ids: list[int]
llm_model_provider_override: str | None
@@ -174,6 +169,7 @@ class DATestPersona(BaseModel):
users: list[str]
groups: list[int]
label_ids: list[int]
featured: bool = False
# Embedded prompt fields (no longer separate prompt_ids)
system_prompt: str | None = None

View File

@@ -8,7 +8,6 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.discord_bot import bulk_create_channel_configs
from onyx.db.discord_bot import create_discord_bot_config
from onyx.db.discord_bot import create_guild_config
@@ -36,14 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
id=persona_id,
name=name,
description="Test persona for Discord bot tests",
num_chunks=5.0,
chunks_above=1,
chunks_below=1,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
is_visible=True,
is_default_persona=False,
featured=False,
deleted=False,
builtin_persona=False,
)

View File

@@ -414,6 +414,24 @@ def test_mock_connector_checkpoint_recovery(
)
assert finished_index_attempt.status == IndexingStatus.FAILED
# Pause the connector immediately to prevent check_for_indexing from
# creating automatic retry attempts while we reset the mock server.
# Without this, the INITIAL_INDEXING status causes immediate retries
# that would consume (or fail against) the mock server before we can
# set up the recovery behavior.
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
# Collect all index attempt IDs created so far (the initial one plus
# any automatic retries that may have started before the pause took effect).
all_prior_attempt_ids: list[int] = []
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair.id,
page=0,
page_size=100,
user_performing_action=admin_user,
)
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
# Verify initial state: both docs should be indexed
with get_session_with_current_tenant() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
@@ -465,17 +483,14 @@ def test_mock_connector_checkpoint_recovery(
)
assert response.status_code == 200
# After the failure, the connector is in repeated error state and paused.
# Set the manual indexing trigger first (while paused), then unpause.
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
# prevent the connector from being re-paused when repeated error state is detected.
# Set the manual indexing trigger, then unpause to allow the recovery run.
CCPairManager.run_once(
cc_pair, from_beginning=False, user_performing_action=admin_user
)
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
index_attempts_to_ignore=[initial_index_attempt.id],
index_attempts_to_ignore=all_prior_attempt_ids,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(

View File

@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
# )
break
if time.monotonic() - start_time > 30:
assert False, "CC pair did not enter repeated error state within 30 seconds"
if time.monotonic() - start_time > 90:
assert False, "CC pair did not enter repeated error state within 90 seconds"
time.sleep(2)

View File

@@ -42,12 +42,10 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,
"custom_config": None,
"fast_default_model_name": default_model,
"is_public": True,
"is_auto_mode": is_auto_mode,
"groups": [],
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json():
for provider in response.json()["providers"]:
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
default_model["name"] if isinstance(default_model, dict) else default_model
)
assert synced_provider["default_model_name"] == expected_default, (
f"Default model should be {expected_default}, "
f"got {synced_provider['default_model_name']}"
)
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None, # noqa: ARG001
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
f"Manual mode provider models should not change. "
f"Initial: {initial_models}, Current: {current_models}"
)
assert (
updated_provider["default_model_name"] == custom_model
), f"Manual mode default model should remain {custom_model}"

View File

@@ -4,22 +4,22 @@ import pytest
import requests
from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMModelFlow
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_llm_for_persona
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
@@ -41,24 +41,30 @@ def _create_llm_provider(
is_public: bool,
is_default: bool,
) -> LLMProviderModel:
provider = LLMProviderModel(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
# Use None instead of False to avoid unique constraint violation
# The is_default_provider column has unique=True, so only one True and one False allowed
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
_provider = upsert_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name,
is_visible=True,
)
],
),
db_session=db_session,
)
db_session.add(provider)
db_session.flush()
if is_default:
update_default_provider(_provider.id, default_model_name, db_session)
provider = db_session.get(LLMProviderModel, _provider.id)
if not provider:
raise ValueError(f"Provider {name} not found")
return provider
@@ -71,12 +77,6 @@ def _create_persona(
persona = Persona(
name=name,
description=f"{name} description",
num_chunks=5,
chunks_above=2,
chunks_below=2,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=provider_name,
llm_model_version_override="gpt-4o-mini",
system_prompt="System prompt",
@@ -243,6 +243,116 @@ def test_can_user_access_llm_provider_or_logic(
)
def test_public_provider_with_persona_restrictions(
users: tuple[DATestUser, DATestUser],
) -> None:
"""Public providers should still enforce persona restrictions.
Regression test for the bug where is_public=True caused
can_user_access_llm_provider() to return True immediately,
bypassing persona whitelist checks entirely.
"""
admin_user, _basic_user = users
with get_session_with_current_tenant() as db_session:
# Public provider with persona restrictions
public_restricted = _create_llm_provider(
db_session,
name="public-persona-restricted",
default_model_name="gpt-4o",
is_public=True,
is_default=True,
)
whitelisted_persona = _create_persona(
db_session,
name="whitelisted-persona",
provider_name=public_restricted.name,
)
non_whitelisted_persona = _create_persona(
db_session,
name="non-whitelisted-persona",
provider_name=public_restricted.name,
)
# Only whitelist one persona
db_session.add(
LLMProvider__Persona(
llm_provider_id=public_restricted.id,
persona_id=whitelisted_persona.id,
)
)
db_session.flush()
db_session.refresh(public_restricted)
admin_model = db_session.get(User, admin_user.id)
assert admin_model is not None
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
# Whitelisted persona — should be allowed
assert can_user_access_llm_provider(
public_restricted,
admin_group_ids,
whitelisted_persona,
)
# Non-whitelisted persona — should be denied despite is_public=True
assert not can_user_access_llm_provider(
public_restricted,
admin_group_ids,
non_whitelisted_persona,
)
# No persona context (e.g. global provider list) — should be denied
# because provider has persona restrictions set
assert not can_user_access_llm_provider(
public_restricted,
admin_group_ids,
persona=None,
)
def test_public_provider_without_persona_restrictions(
users: tuple[DATestUser, DATestUser],
) -> None:
"""Public providers with no persona restrictions remain accessible to all."""
admin_user, basic_user = users
with get_session_with_current_tenant() as db_session:
public_unrestricted = _create_llm_provider(
db_session,
name="public-unrestricted",
default_model_name="gpt-4o",
is_public=True,
is_default=True,
)
any_persona = _create_persona(
db_session,
name="any-persona",
provider_name=public_unrestricted.name,
)
admin_model = db_session.get(User, admin_user.id)
basic_model = db_session.get(User, basic_user.id)
assert admin_model is not None
assert basic_model is not None
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
# Any user, any persona — all allowed
assert can_user_access_llm_provider(
public_unrestricted, admin_group_ids, any_persona
)
assert can_user_access_llm_provider(
public_unrestricted, basic_group_ids, any_persona
)
assert can_user_access_llm_provider(
public_unrestricted, admin_group_ids, persona=None
)
def test_get_llm_for_persona_falls_back_when_access_denied(
users: tuple[DATestUser, DATestUser],
) -> None:
@@ -270,24 +380,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()
@@ -321,13 +413,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
assert (
allowed_llm.config.model_name
== restricted_provider.model_configurations[0].name
)
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert fallback_llm.config.model_name == default_provider.default_model_name
assert (
fallback_llm.config.model_name
== default_provider.model_configurations[0].name
)
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -346,6 +444,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
name="public-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)
@@ -365,7 +464,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -380,7 +479,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_providers = admin_response.json()["providers"]
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
@@ -396,6 +495,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
name="default-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)

View File

@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider since basic_user can access the persona
provider_names = [p["name"] for p in providers]
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
# Should succeed (persona_id=0 refers to default persona, which is public)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should NOT include the restricted provider since basic_user is not in group2
provider_names = [p["name"] for p in providers]
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
# Should succeed - admins can access any persona
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider
provider_names = [p["name"] for p in providers]
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should return the public provider
assert len(providers) > 0

View File

@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
result = db_session.execute(
text(
"""
SELECT id, name, builtin_persona, is_default_persona, deleted
SELECT id, name, builtin_persona, featured, deleted
FROM persona
WHERE builtin_persona = true
ORDER BY id
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
assert default[0] == 0, "Default assistant should have ID 0"
assert default[1] == "Assistant", "Should be named 'Assistant'"
assert default[2] is True, "Should be builtin"
assert default[3] is True, "Should be default"
assert default[3] is True, "Should be featured"
assert default[4] is False, "Should not be deleted"
# Check tools are properly associated

View File

@@ -195,11 +195,7 @@ def _base_persona_body(**overrides: object) -> dict:
"description": "test",
"system_prompt": "test",
"task_prompt": "",
"num_chunks": 5,
"is_public": True,
"recency_bias": "auto",
"llm_filter_extraction": False,
"llm_relevance_filter": False,
"datetime_aware": False,
"document_set_ids": [],
"tool_ids": [],

Some files were not shown because too many files have changed in this diff Show More