Compare commits

51 Commits

Author SHA1 Message Date
Raunak Bhagat
8f5d7e271a refactor: migrate ModalHeader to Content layout
Modal.Header now uses opal Content component for icon + title rendering.
Description passed to Content directly with a hidden
DialogPrimitive.Description for accessibility. Close button absolutely
positioned per Figma mocks.
2026-03-01 12:06:18 -08:00
Raunak Bhagat
bb6e20614d refactor(opal): split ContentLg into ContentXl + ContentLg
- ContentXl: variant="heading" (icon row on top, flex-col) with
  moreIcon1/moreIcon2 support
- ContentLg: simplified to always flex-row (variant="section")
- Section preset font updated to font-heading-h3-muted
- Renamed type aliases: XlContentProps, LgContentProps, MdContentProps,
  SmContentProps
- Renamed internal layout files to size-based names (ContentLg, ContentMd,
  ContentSm)
2026-03-01 12:05:56 -08:00
Danelegend
bd9319e592 feat: LLM Provider Rework (#8761)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-02-28 01:29:49 +00:00
Nikolas Garza
db5955d6f2 fix(ee): show Access Restricted page when seat limit exceeded (#8877) 2026-02-28 01:26:00 +00:00
Raunak Bhagat
5e447440ea refactor(Suggestions): migrate to opal Interactive + Content (#8881) 2026-02-27 23:39:20 +00:00
Justin Tahara
78c6ca39b8 fix(minio): No cURL in minio container (#8876) 2026-02-27 22:37:42 +00:00
Raunak Bhagat
71a7cf09b3 refactor(opal): migrate LineItemLayout to Content/ContentAction (#8824) 2026-02-27 22:27:09 +00:00
dependabot[bot]
91d30a0156 chore(deps): bump actions/download-artifact from 4.2.1 to 7.0.0 (#8474)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 22:11:03 +00:00
dependabot[bot]
7b30752767 chore(deps): bump rollup from 4.52.5 to 4.59.0 in /web (#8782)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 21:57:10 +00:00
Justin Tahara
4450ecf07c fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 21:45:31 +00:00
Danelegend
0e6b766996 feat: Add python tool as default for default persona (#8857) 2026-02-27 21:32:55 +00:00
dependabot[bot]
12c8cd338b chore(deps): bump werkzeug from 3.1.5 to 3.1.6 (#8615)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 21:08:33 +00:00
dependabot[bot]
ad5688bf65 chore(deps-dev): bump rollup from 4.55.1 to 4.59.0 in /widget (#8863)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-27 21:02:20 +00:00
Jamison Lahman
d2deefd1f1 chore(whitelabeling): always show sidebar icon without logo icon (#8860)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-27 20:36:11 +00:00
Jamison Lahman
18b90d405d chore(deps): upgrade fastapi: 0.128.0->0.133.1 (#8862) 2026-02-27 20:26:27 +00:00
Raunak Bhagat
8394e8837b feat(opal): extract widthVariant to shared and add to Content (#8859) 2026-02-27 19:50:32 +00:00
Jamison Lahman
f06df891c4 chore(fe): InputSelect has a min-width (#8858) 2026-02-27 19:20:37 +00:00
Wenxi
d6d5e72c18 feat(ods): whois utility to find tenant_ids and admin emails (#8855) 2026-02-27 18:21:29 +00:00
Danelegend
449f5d62f9 fix: Code output extending over thinking bounds (#8837) 2026-02-27 08:26:54 +00:00
Yuhong Sun
4d256c5666 chore: remove instance of Assistant from frontend (#8848)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-02-27 04:22:28 +00:00
Danelegend
2e53496f46 feat: Code interpreter admin page visuals (#8729) 2026-02-27 04:01:02 +00:00
acaprau
63a206706a docs(best practices): Add comment about import-time side effects and main.py files (#8820)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-27 01:29:56 +00:00
Nikolas Garza
28427b3e5f fix(metrics): restore default HTTP request counter and histogram metrics (#8842) 2026-02-27 00:53:22 +00:00
Justin Tahara
3cafcd8a5e chore(llm): add OpenRouter nightly tests (#8818) 2026-02-26 23:54:25 +00:00
Justin Tahara
f2c50b7bb5 chore(llm): add Ollama nightly tests (#8817) 2026-02-26 23:28:40 +00:00
Jamison Lahman
6b28c6bbfc fix(fe): Search Actions popover has consistent hover states (#8826) 2026-02-26 23:16:09 +00:00
Justin Tahara
226e801665 chore(llm): add Azure nightly tests (#8816) 2026-02-26 23:05:03 +00:00
Justin Tahara
be13aa1310 chore(llm): add Vertex AI nightly tests (#8813) 2026-02-26 22:38:05 +00:00
Nikolas Garza
45d38c4906 feat(metrics): add per-tenant Prometheus metrics (#8822) 2026-02-26 22:37:35 +00:00
Danelegend
8aab518532 fix: Admin page modal centering excludes sidebar (#8823) 2026-02-26 22:27:58 +00:00
Nikolas Garza
da6ce10e86 test(scim): add integration tests for SCIM token management (#8819) 2026-02-26 22:22:16 +00:00
Nikolas Garza
aaf8253520 fix(ee): show subscription text on expired access page for cloud users (#8804) 2026-02-26 22:15:44 +00:00
Jamison Lahman
7c7f81b164 chore(fe): add feature agent to editor page (#8814) 2026-02-26 22:12:20 +00:00
Justin Tahara
2d4a3c72e9 chore(llm): Nightly Bedrock Tests (#8812) 2026-02-26 22:10:31 +00:00
acaprau
7c51712018 fix(db ssl): Remove import-time side effect of creating SSL context if IAM enabled (#8811) 2026-02-26 21:37:13 +00:00
Evan Lohn
aa5614695d feat: sharepoint tenant avoid org get (#8802) 2026-02-26 21:28:56 +00:00
Jamison Lahman
8d7255d3c4 chore(fe): support featured agents w/o being public (#8809) 2026-02-26 21:16:23 +00:00
Evan Lohn
d403498f48 feat: context injection unification (#8687) 2026-02-26 21:11:19 +00:00
dependabot[bot]
9ef3095c17 chore(deps): bump pypdf from 6.6.2 to 6.7.3 (#8808)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 20:42:01 +00:00
Justin Tahara
a39e93a0cb chore(llm): LLM Integration Tests Generic Setup (#8803) 2026-02-26 19:59:19 +00:00
Jamison Lahman
46d73cdfee fix(docker): prefer user runtime docker socket (#8799) 2026-02-26 10:55:44 -08:00
Raunak Bhagat
1e04ce78e0 feat(opal): add Hoverable compound component (#8798) 2026-02-26 17:08:53 +00:00
Jamison Lahman
f9b81c1725 feat(agents): share agents with labels or featured (#8742)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-26 16:21:05 +00:00
SubashMohan
3bc1b89fee fix(memory): timeline UI alignment issues and highlighting issue (#8753) 2026-02-26 08:46:43 +00:00
Nikolas Garza
01743d99d4 fix(billing): handle manual license users without Stripe subscription (#8787) 2026-02-26 08:07:14 +00:00
acaprau
092c1db7e0 chore(opensearch): Allow programmatic schema updates (#8794) 2026-02-26 07:49:56 +00:00
acaprau
40ac0d859a chore(opensearch): OpenSearchClient implements context manager, also closes on del (#8781)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-26 07:38:16 +00:00
SubashMohan
929e58361f fix: resolve OAuth token manager using masked secrets (#8673) 2026-02-26 07:06:51 +00:00
SubashMohan
6d472df7c5 fix(timeline): Fix double-collapse and improve tool status messages (#8751) 2026-02-26 07:05:48 +00:00
acaprau
cfa7acd904 chore(opensearch): MT cloud should verify index on document index init, and do cluster setup once at start (#8776) 2026-02-26 06:42:06 +00:00
Danelegend
5c5a6f943b chore: deprecate llm provider fields (#8783) 2026-02-26 05:27:28 +00:00
268 changed files with 10094 additions and 3314 deletions

View File

@@ -9,7 +9,8 @@ inputs:
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: true
required: false
default: ""
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
@@ -17,6 +18,14 @@ inputs:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
api-version:
description: "Optional NIGHTLY_LLM_API_VERSION"
required: false
default: ""
deployment-name:
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
@@ -59,6 +68,7 @@ runs:
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AWS_REGION_NAME=us-west-2
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
@@ -82,6 +92,8 @@ runs:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
@@ -91,11 +103,6 @@ runs:
max_attempts: 2
retry_wait_seconds: 10
command: |
if [ -z "${MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -110,10 +117,13 @@ runs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e AWS_REGION_NAME=us-west-2 \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \

View File

@@ -1,44 +0,0 @@
name: Nightly LLM Provider Chat Tests (OpenAI)
concurrency:
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
openai-provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
provider: openai
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
strict: true
secrets:
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [openai-provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: openai-provider-chat-test
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -0,0 +1,56 @@
name: Nightly LLM Provider Chat Tests
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
azure_api_key: ${{ secrets.AZURE_API_KEY }}
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -114,8 +114,10 @@ jobs:
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:

View File

@@ -603,7 +603,7 @@ jobs:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
with:
pattern: screenshot-diff-summary-*
path: summaries/

View File

@@ -89,6 +89,10 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

View File

@@ -3,33 +3,66 @@ name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
provider:
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
required: true
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
type: string
models:
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
required: true
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
type: string
bedrock_models:
description: "Comma-separated models for bedrock"
required: false
default: ""
type: string
vertex_ai_models:
description: "Comma-separated models for vertex_ai"
required: false
default: ""
type: string
azure_models:
description: "Comma-separated models for azure"
required: false
default: ""
type: string
ollama_models:
description: "Comma-separated models for ollama_chat"
required: false
default: ""
type: string
openrouter_models:
description: "Comma-separated models for openrouter"
required: false
default: ""
type: string
azure_api_base:
description: "API base for azure provider"
required: false
default: ""
type: string
strict:
description: "Pass-through value for NIGHTLY_LLM_STRICT"
description: "Default NIGHTLY_LLM_STRICT passed to tests"
required: false
default: true
type: boolean
api_base:
description: "Optional NIGHTLY_LLM_API_BASE override"
required: false
default: ""
type: string
custom_config_json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
required: false
default: ""
type: string
secrets:
provider_api_key:
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
required: true
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
azure_api_key:
required: false
ollama_api_key:
required: false
openrouter_api_key:
required: false
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
@@ -38,29 +71,8 @@ on:
permissions:
contents: read
env:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
jobs:
validate-inputs:
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Validate required nightly provider inputs
run: |
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
build-backend-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -90,7 +102,6 @@ jobs:
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -119,7 +130,6 @@ jobs:
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -149,11 +159,75 @@ jobs:
provider-chat-test:
needs:
[build-backend-image, build-model-server-image, build-integration-image]
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_secret: bedrock_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_secret: azure_api_key
custom_config_secret: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_secret: ollama_api_key
custom_config_secret: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_secret: openrouter_api_key
custom_config_secret: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
required: false
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
@@ -167,12 +241,14 @@ jobs:
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
models: ${{ env.NIGHTLY_LLM_MODELS }}
provider-api-key: ${{ secrets.provider_api_key }}
strict: ${{ env.NIGHTLY_LLM_STRICT }}
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
@@ -194,7 +270,7 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log
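
The per-provider credentials above are selected indirectly: each matrix row names the secret to read, and the expression matrix.api_key_secret && secrets[matrix.api_key_secret] || '' resolves to an empty string when no secret name is set. A minimal Python sketch of that lookup, with plain dicts standing in for the Actions matrix and secrets contexts (names here are illustrative):

# Hypothetical model of the workflow's per-provider secret indirection.
secrets = {"openai_api_key": "sk-test", "anthropic_api_key": "sk-ant"}

def resolve_api_key(matrix_entry: dict) -> str:
    secret_name = matrix_entry.get("api_key_secret", "")
    # Mirrors: matrix.api_key_secret && secrets[matrix.api_key_secret] || ''
    return secrets.get(secret_name, "") if secret_name else ""

assert resolve_api_key({"provider": "openai", "api_key_secret": "openai_api_key"}) == "sk-test"
assert resolve_api_key({"provider": "vertex_ai", "api_key_secret": ""}) == ""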

View File

@@ -0,0 +1,69 @@
"""add python tool on default
Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None
PYTHON_TOOL_NAME = "python"
def upgrade() -> None:
conn = op.get_bind()
# Look up the PythonTool id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
tool_id = result[0]
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE persona_id = 0 AND tool_id = :tool_id
"""
),
{"tool_id": result[0]},
)
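
The ON CONFLICT DO NOTHING clause is what makes this data migration safe to re-run. A self-contained sketch of that idempotency, using an in-memory SQLite table as a stand-in for persona__tool (PostgreSQL accepts the same clause):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE persona__tool (persona_id INTEGER, tool_id INTEGER,"
    " PRIMARY KEY (persona_id, tool_id))"
)
for _ in range(2):  # simulate running the migration twice
    conn.execute(
        "INSERT INTO persona__tool (persona_id, tool_id) VALUES (0, ?)"
        " ON CONFLICT DO NOTHING",
        (42,),
    )
count = conn.execute("SELECT COUNT(*) FROM persona__tool").fetchone()[0]
assert count == 1  # the second insert was a no-op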

View File

@@ -0,0 +1,70 @@
"""llm provider deprecate fields
Revision ID: c0c937d5c9e5
Revises: 8ffcc2bcfc11
Create Date: 2026-02-25 17:35:46.125102
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0c937d5c9e5"
down_revision = "8ffcc2bcfc11"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
op.drop_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
type_="unique",
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore unique constraint on is_default_provider
op.create_unique_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
["is_default_provider"],
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)
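
Note the ordering in downgrade: NULLs are backfilled before the NOT NULL constraint is restored, since tightening the constraint fails if any row still holds NULL. A sketch of the same backfill-then-constrain pattern (SQLite used as a stand-in; the real migration targets Postgres):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE llm_provider (id INTEGER PRIMARY KEY, default_model_name TEXT)")
conn.execute("INSERT INTO llm_provider (default_model_name) VALUES (NULL)")
# 1) Backfill first so no NULLs remain:
conn.execute("UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL")
# 2) Only then is it safe to restore the constraint. SQLite cannot ALTER a
#    column in place, so the Postgres statement is shown as a comment:
#    ALTER TABLE llm_provider ALTER COLUMN default_model_name SET NOT NULL;
assert conn.execute(
    "SELECT COUNT(*) FROM llm_provider WHERE default_model_name IS NULL"
).fetchone()[0] == 0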

View File

@@ -322,6 +322,7 @@ def list_users(
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -365,6 +366,7 @@ def get_user(
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
@@ -721,6 +723,7 @@ def list_groups(
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -757,6 +760,7 @@ def get_group(
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
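
Committing immediately after update_token_last_used decouples the token bookkeeping from the rest of the request: the last-used timestamp persists even if the subsequent filter parsing or lookup raises. A minimal sketch of that commit-early pattern, with a toy class standing in for the real ScimDAL:

from dataclasses import dataclass, field

@dataclass
class FakeScimDAL:
    """Stand-in for the real ScimDAL; records commits for illustration."""
    committed: list = field(default_factory=list)
    pending: str = ""

    def update_token_last_used(self, token_id: int) -> None:
        self.pending = f"token {token_id} last_used = now"

    def commit(self) -> None:
        self.committed.append(self.pending)

dal = FakeScimDAL()
dal.update_token_last_used(7)
dal.commit()  # bookkeeping is persisted here, before the main read path
try:
    raise ValueError("unsupported SCIM filter")  # a later failure...
except ValueError:
    pass
assert dal.committed  # ...does not undo the committed last-used update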

View File

@@ -20,6 +20,7 @@ from ee.onyx.server.enterprise_settings.store import (
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -117,15 +118,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
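
The default-model selection above prefers the first visible model configuration and falls back to the first configuration of any kind. A minimal sketch of that preference order (ModelConfig is a simplified stand-in for the real model_configurations entries):

from dataclasses import dataclass

@dataclass
class ModelConfig:
    name: str
    is_visible: bool

def pick_default_model(configs: list[ModelConfig]) -> str | None:
    if not configs:
        return None
    visible = [mc for mc in configs if mc.is_visible]
    return (visible[0] if visible else configs[0]).name

assert pick_default_model([ModelConfig("a", False), ModelConfig("b", True)]) == "b"
assert pick_default_model([ModelConfig("a", False)]) == "a"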

View File

@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
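
A small worked illustration of the new branch: a valid license with more active users than purchased seats now surfaces SEAT_LIMIT_EXCEEDED while leaving EE features on. The types below are stand-ins for the real models:

from dataclasses import dataclass
from enum import Enum

class ApplicationStatus(str, Enum):  # stand-in for the real enum
    ACTIVE = "active"
    SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"

@dataclass
class LicenseMetadata:
    seats: int
    used_seats: int

meta = LicenseMetadata(seats=50, used_seats=53)
status = (
    ApplicationStatus.SEAT_LIMIT_EXCEEDED
    if meta.used_seats > meta.seats
    else ApplicationStatus.ACTIVE
)
assert status is ApplicationStatus.SEAT_LIMIT_EXCEEDED  # 53 > 50, EE stays enabled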

View File

@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, default_model, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -58,16 +58,27 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -115,15 +126,26 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -141,8 +163,13 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": oauth_config.client_id,
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -161,6 +188,12 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
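
The unwrap helpers above normalize values that may arrive either wrapped (SensitiveValue, masked by default) or as plain strings. A self-contained sketch of the pattern, with a toy SensitiveValue standing in for the real class (its actual API may differ):

from typing import Generic, TypeVar

T = TypeVar("T")

class SensitiveValue(Generic[T]):
    """Toy stand-in: masks on display, reveals only on explicit request."""
    def __init__(self, value: T) -> None:
        self._value = value
    def get_value(self, apply_mask: bool = True):
        return "****" if apply_mask else self._value

def unwrap_sensitive_str(value) -> str:
    # Accepts SensitiveValue[str] | str, returning the raw string either way.
    if isinstance(value, SensitiveValue):
        return value.get_value(apply_mask=False)
    return value

assert unwrap_sensitive_str("plain-client-id") == "plain-client-id"
assert unwrap_sensitive_str(SensitiveValue("secret-client-id")) == "secret-client-id"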

View File

@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,

View File

@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)

View File

@@ -3,7 +3,6 @@ import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
@@ -163,13 +162,11 @@ class ChatStateContainer:
def run_chat_loop_with_state_containers(
func: Callable[..., None],
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
*args: Any,
**kwargs: Any,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
*args: Additional positional arguments for func
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
arg1, arg2, kwarg1=value1
)
for packet in packets:
# Process packets
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
def run_with_exception_capture() -> None:
try:
# Ensure state_container is passed explicitly, removing it from kwargs if present
kwargs_with_state = {**kwargs, "state_container": state_container}
func(emitter, *args, **kwargs_with_state)
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
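
The refactor replaces the generic func/*args/**kwargs plumbing with a single chat_loop_func(emitter, state_container) callable, which makes the thread wrapper's contract explicit. A minimal sketch of the wrapper pattern under that assumption, with string packets and a toy Emitter standing in for the real types:

import threading
from queue import Empty, Queue
from typing import Callable, Generator

class Emitter:  # toy stand-in for the real Emitter
    def __init__(self) -> None:
        self.queue: "Queue[str | None]" = Queue()
    def emit(self, packet: str) -> None:
        self.queue.put(packet)

def run_chat_loop(
    chat_loop_func: Callable[[Emitter], None],
    emitter: Emitter,
) -> Generator[str, None, None]:
    def run_with_exception_capture() -> None:
        try:
            chat_loop_func(emitter)  # explicit contract: no *args/**kwargs
        except Exception as e:
            emitter.emit(f"error: {e}")  # surface failures as packets
        finally:
            emitter.queue.put(None)  # sentinel: the loop has finished

    threading.Thread(target=run_with_exception_capture, daemon=True).start()
    while True:
        try:
            packet = emitter.queue.get(timeout=1)
        except Empty:
            continue
        if packet is None:
            return
        yield packet

assert list(run_chat_loop(lambda em: em.emit("hello"), Emitter())) == ["hello"]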

View File

@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
project_image_files: list[ChatLoadedFile],
context_image_files: list[ChatLoadedFile],
additional_context: str | None,
token_counter: Callable[[str], int],
tool_id_to_name_map: dict[int, str],
@@ -541,11 +541,11 @@ def convert_chat_history(
)
# Add the user message with image files attached
# If this is the last USER message, also include project_image_files
# Note: project image file tokens are NOT counted in the token count
# If this is the last USER message, also include context_image_files
# Note: context image file tokens are NOT counted in the token count
if idx == last_user_message_idx:
if project_image_files:
image_files.extend(project_image_files)
if context_image_files:
image_files.extend(context_image_files)
if additional_context:
simple_messages.append(

View File

@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.chat.prompt_utils import build_reminder_message
from onyx.chat.prompt_utils import build_system_prompt
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
MAX_LLM_CYCLES = 6
def _build_project_file_citation_mapping(
project_file_metadata: list[ProjectFileMetadata],
def _build_context_file_citation_mapping(
file_metadata: list[ContextFileMetadata],
starting_citation_num: int = 1,
) -> CitationMapping:
"""Build citation mapping for project files.
"""Build citation mapping for context files.
Converts project file metadata into SearchDoc objects that can be cited.
Converts context file metadata into SearchDoc objects that can be cited.
Citation numbers start from the provided starting number.
Args:
project_file_metadata: List of project file metadata
file_metadata: List of context file metadata
starting_citation_num: Starting citation number (default: 1)
Returns:
@@ -221,8 +221,7 @@ def _build_project_file_citation_mapping(
"""
citation_mapping: CitationMapping = {}
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
# Create a SearchDoc for each project file
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
search_doc = SearchDoc(
document_id=file_meta.file_id,
chunk_ind=0,
@@ -242,29 +241,28 @@ def _build_project_file_citation_mapping(
def _build_project_message(
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
token_counter: Callable[[str], int] | None,
) -> list[ChatMessageSimple]:
"""Build messages for project / tool-backed files.
"""Build messages for context-injected / tool-backed files.
Returns up to two messages:
1. The full-text project files message (if project_file_texts is populated).
1. The full-text files message (if file_texts is populated).
2. A lightweight metadata message for files the LLM should access via the
FileReaderTool (e.g. oversized chat-attached files or project files that
don't fit in context).
FileReaderTool (e.g. oversized files that don't fit in context).
"""
if not project_files:
if not context_files:
return []
messages: list[ChatMessageSimple] = []
if project_files.project_file_texts:
if context_files.file_texts:
messages.append(
_create_project_files_message(project_files, token_counter=None)
_create_context_files_message(context_files, token_counter=None)
)
if project_files.file_metadata_for_tool and token_counter:
if context_files.file_metadata_for_tool and token_counter:
messages.append(
_create_file_tool_metadata_message(
project_files.file_metadata_for_tool, token_counter
context_files.file_metadata_for_tool, token_counter
)
)
return messages
@@ -275,7 +273,7 @@ def construct_message_history(
custom_agent_prompt: ChatMessageSimple | None,
simple_chat_history: list[ChatMessageSimple],
reminder_message: ChatMessageSimple | None,
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
token_counter: Callable[[str], int] | None = None,
@@ -289,7 +287,7 @@ def construct_message_history(
# Build the project / file-metadata messages up front so we can use their
# actual token counts for the budget.
project_messages = _build_project_message(project_files, token_counter)
project_messages = _build_project_message(context_files, token_counter)
project_messages_tokens = sum(m.token_count for m in project_messages)
history_token_budget = available_tokens
@@ -445,17 +443,17 @@ def construct_message_history(
)
# Attach project images to the last user message
if project_files and project_files.project_image_files:
if context_files and context_files.image_files:
existing_images = last_user_message.image_files or []
last_user_message = ChatMessageSimple(
message=last_user_message.message,
token_count=last_user_message.token_count,
message_type=last_user_message.message_type,
image_files=existing_images + project_files.project_image_files,
image_files=existing_images + context_files.image_files,
)
# Build the final message list according to README ordering:
# [system], [history_before_last_user], [custom_agent], [project_files],
# [system], [history_before_last_user], [custom_agent], [context_files],
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
result = [system_prompt] if system_prompt else []
@@ -466,14 +464,14 @@ def construct_message_history(
if custom_agent_prompt:
result.append(custom_agent_prompt)
# 3. Add project files / file-metadata messages (inserted before last user message)
# 3. Add context files / file-metadata messages (inserted before last user message)
result.extend(project_messages)
# 4. Add forgotten-files metadata (right before the user's question)
if forgotten_files_message:
result.append(forgotten_files_message)
# 5. Add last user message (with project images attached)
# 5. Add last user message (with context images attached)
result.append(last_user_message)
# 6. Add messages after last user message (tool calls, responses, etc.)
@@ -547,11 +545,11 @@ def _create_file_tool_metadata_message(
)
def _create_project_files_message(
project_files: ExtractedProjectFiles,
def _create_context_files_message(
context_files: ExtractedContextFiles,
token_counter: Callable[[str], int] | None, # noqa: ARG001
) -> ChatMessageSimple:
"""Convert project files to a ChatMessageSimple message.
"""Convert context files to a ChatMessageSimple message.
Format follows the README specification for document representation.
"""
@@ -559,7 +557,7 @@ def _create_project_files_message(
# Format as documents JSON as described in README
documents_list = []
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
for idx, file_text in enumerate(context_files.file_texts, start=1):
documents_list.append(
{
"document": idx,
@@ -570,10 +568,10 @@ def _create_project_files_message(
documents_json = json.dumps({"documents": documents_list}, indent=2)
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
# Use pre-calculated token count from project_files
# Use pre-calculated token count from context_files
return ChatMessageSimple(
message=message_content,
token_count=project_files.total_token_count,
token_count=context_files.total_token_count,
message_type=MessageType.USER,
)
@@ -584,7 +582,7 @@ def run_llm_loop(
simple_chat_history: list[ChatMessageSimple],
tools: list[Tool],
custom_agent_prompt: str | None,
project_files: ExtractedProjectFiles,
context_files: ExtractedContextFiles,
persona: Persona | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
@@ -627,9 +625,9 @@ def run_llm_loop(
# Add project file citation mappings if project files are present
project_citation_mapping: CitationMapping = {}
if project_files.project_file_metadata:
project_citation_mapping = _build_project_file_citation_mapping(
project_files.project_file_metadata
if context_files.file_metadata:
project_citation_mapping = _build_context_file_citation_mapping(
context_files.file_metadata
)
citation_processor.update_citation_mapping(project_citation_mapping)
@@ -647,7 +645,7 @@ def run_llm_loop(
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
# One future workaround is to include the images as separate user messages with citation information and process those.
always_cite_documents: bool = bool(
project_files.project_as_filter or project_files.project_file_texts
context_files.use_as_search_filter or context_files.file_texts
)
should_cite_documents: bool = False
ran_image_gen: bool = False
@@ -788,7 +786,7 @@ def run_llm_loop(
custom_agent_prompt=custom_agent_prompt_msg,
simple_chat_history=simple_chat_history,
reminder_message=reminder_msg,
project_files=project_files,
context_files=context_files,
available_tokens=available_tokens,
token_counter=token_counter,
all_injected_file_metadata=all_injected_file_metadata,
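
Citation numbers for context files are assigned positionally from starting_citation_num, so they can be appended after any citations already produced elsewhere in the turn. A sketch of that numbering, with ContextFileMetadata reduced to the fields the mapping needs (the real code builds full SearchDoc objects):

from dataclasses import dataclass

@dataclass
class ContextFileMetadata:  # reduced stand-in
    file_id: str
    filename: str

def build_citation_mapping(
    files: list[ContextFileMetadata], starting_citation_num: int = 1
) -> dict[int, str]:
    # File ids suffice here to show the positional numbering.
    return {
        idx: meta.file_id
        for idx, meta in enumerate(files, start=starting_citation_num)
    }

files = [ContextFileMetadata("f1", "notes.txt"), ContextFileMetadata("f2", "spec.md")]
assert build_citation_mapping(files, starting_citation_num=3) == {3: "f1", 4: "f2"}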

View File

@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ProjectSearchConfig(BaseModel):
"""Configuration for search tool availability in project context."""
search_usage: SearchToolUsage
disable_forced_tool: bool
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
file_id: str | None = None
class ProjectFileMetadata(BaseModel):
"""Metadata for a project file to enable citation support."""
class ContextFileMetadata(BaseModel):
"""Metadata for a context-injected file to enable citation support."""
file_id: str
filename: str
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
all_injected_file_metadata: dict[str, FileToolMetadata]
class ExtractedProjectFiles(BaseModel):
project_file_texts: list[str]
project_image_files: list[ChatLoadedFile]
project_as_filter: bool
class ExtractedContextFiles(BaseModel):
"""Result of attempting to load user files (from a project or persona) into context."""
file_texts: list[str]
image_files: list[ChatLoadedFile]
use_as_search_filter: bool
total_token_count: int
# Metadata for project files to enable citations
project_file_metadata: list[ProjectFileMetadata]
# None if not a project
project_uncapped_token_count: int | None
# Lightweight metadata for files exposed via FileReaderTool
# (populated when files don't fit in context and vector DB is disabled)
# (populated when files don't fit in context and vector DB is disabled).
file_metadata: list[ContextFileMetadata]
uncapped_token_count: int | None
file_metadata_for_tool: list[FileToolMetadata] = []
class SearchParams(BaseModel):
"""Resolved search filter IDs and search-tool usage for a chat turn."""
search_project_id: int | None
search_persona_id: int | None
search_usage: SearchToolUsage
class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None

View File

@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
An overview can be found in the README.md file in this directory.
"""
import io
import re
import traceback
from collections.abc import Callable
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ProjectSearchConfig
from onyx.chat.models import SearchParams
from onyx.chat.models import StreamingError
from onyx.chat.models import ToolCallResponse
from onyx.chat.prompt_utils import calculate_reserved_tokens
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
pass
if project_id:
project_files = get_user_files_from_project(
user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
for uf in project_files:
for uf in user_files:
user_file_ids.add(uf.id)
return _AvailableFiles(
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
return chat_files
def _extract_project_file_texts_and_images(
def resolve_context_user_files(
persona: Persona,
project_id: int | None,
user_id: UUID | None,
db_session: Session,
) -> list[UserFile]:
"""Apply the precedence rule to decide which user files to load.
A custom persona fully supersedes the project. When a chat uses a
custom persona, the project is purely organisational — its files are
never loaded and never made searchable.
Custom persona → persona's own user_files (may be empty).
Default persona inside a project → project files.
Otherwise → empty list.
"""
if persona.id != DEFAULT_PERSONA_ID:
return list(persona.user_files) if persona.user_files else []
if project_id:
return get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return []
def _empty_extracted_context_files() -> ExtractedContextFiles:
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=False,
total_token_count=0,
file_metadata=[],
uncapped_token_count=None,
)
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
"""Extract text content from an InMemoryChatFile.
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
ingestion — decode directly.
DOC / CSV / other text types: the content is the original file bytes —
use extract_file_text which handles encoding detection and format parsing.
"""
try:
if f.file_type == ChatFileType.PLAIN_TEXT:
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
return extract_file_text(
file=io.BytesIO(f.content),
file_name=f.filename or "",
break_on_unprocessable=False,
)
except Exception:
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
return None
def extract_context_files(
user_files: list[UserFile],
llm_max_context_window: int,
reserved_token_count: int,
db_session: Session,
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
# 60% of the LLM's max context window. The other benefit is that for projects with
# more files, this makes it so that we don't throw away the history too quickly every time.
max_llm_context_percentage: float = 0.6,
) -> ExtractedProjectFiles:
"""Extract text content from project files if they fit within the context window.
) -> ExtractedContextFiles:
"""Load user files into context if they fit; otherwise flag for search.
The caller is responsible for deciding *which* user files to pass in
(project files, persona files, etc.). This function only cares about
the all-or-nothing fit check and the actual content loading.
Args:
project_id: The project ID to load files from
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
reserved_token_count: Number of tokens to reserve for other content
db_session: Database session
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
Returns:
ExtractedProjectFiles containing:
- List of text content strings from project files (text files only)
- List of image files from project (ChatLoadedFile objects)
- Project id if the project should be provided as a filter in search or None if not.
ExtractedContextFiles containing:
- List of text content strings from context files (text files only)
- List of image files from context (ChatLoadedFile objects)
- Total token count of all extracted files
- File metadata for context files
- Uncapped token count of all extracted files
- File metadata for files that don't fit in context and vector DB is disabled
"""
# TODO I believe this is not handling all file types correctly.
project_as_filter = False
if not project_id:
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=None,
)
# TODO(yuhong): I believe this is not handling all file types correctly.
if not user_files:
return _empty_extracted_context_files()
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
max_actual_tokens = (
llm_max_context_window - reserved_token_count
) * max_llm_context_percentage
# Calculate total token count for all user files in the project
project_tokens = get_project_token_count(
project_id=project_id,
user_id=user_id,
if aggregate_tokens >= max_actual_tokens:
tool_metadata = []
use_as_search_filter = not DISABLE_VECTOR_DB
if DISABLE_VECTOR_DB:
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=use_as_search_filter,
total_token_count=0,
file_metadata=[],
uncapped_token_count=aggregate_tokens,
file_metadata_for_tool=tool_metadata,
)
# Files fit — load them into context
user_file_map = {str(uf.id): uf for uf in user_files}
in_memory_files = load_in_memory_chat_files(
user_file_ids=[uf.id for uf in user_files],
db_session=db_session,
)
project_file_texts: list[str] = []
project_image_files: list[ChatLoadedFile] = []
project_file_metadata: list[ProjectFileMetadata] = []
file_texts: list[str] = []
image_files: list[ChatLoadedFile] = []
file_metadata: list[ContextFileMetadata] = []
total_token_count = 0
if project_tokens < max_actual_tokens:
# Load project files into memory using cached plaintext when available
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
if project_user_files:
# Create a mapping from file_id to UserFile for token count lookup
user_file_map = {str(file.id): file for file in project_user_files}
project_file_ids = [file.id for file in project_user_files]
in_memory_project_files = load_in_memory_chat_files(
user_file_ids=project_file_ids,
db_session=db_session,
for f in in_memory_files:
uf = user_file_map.get(str(f.file_id))
if f.file_type.is_text_file():
text_content = _extract_text_from_in_memory_file(f)
if not text_content:
continue
file_texts.append(text_content)
file_metadata.append(
ContextFileMetadata(
file_id=str(f.file_id),
filename=f.filename or f"file_{f.file_id}",
file_content=text_content,
)
)
if uf and uf.token_count:
total_token_count += uf.token_count
elif f.file_type == ChatFileType.IMAGE:
token_count = uf.token_count if uf and uf.token_count else 0
total_token_count += token_count
image_files.append(
ChatLoadedFile(
file_id=f.file_id,
content=f.content,
file_type=f.file_type,
filename=f.filename,
content_text=None,
token_count=token_count,
)
)
# Extract text content from loaded files
for file in in_memory_project_files:
if file.file_type.is_text_file():
try:
text_content = file.content.decode("utf-8", errors="ignore")
# Strip null bytes
text_content = text_content.replace("\x00", "")
if text_content:
project_file_texts.append(text_content)
# Add metadata for citation support
project_file_metadata.append(
ProjectFileMetadata(
file_id=str(file.file_id),
filename=file.filename or f"file_{file.file_id}",
file_content=text_content,
)
)
# Add token count for text file
user_file = user_file_map.get(str(file.file_id))
if user_file and user_file.token_count:
total_token_count += user_file.token_count
except Exception:
# Skip files that can't be decoded
pass
elif file.file_type == ChatFileType.IMAGE:
# Convert InMemoryChatFile to ChatLoadedFile
user_file = user_file_map.get(str(file.file_id))
token_count = (
user_file.token_count
if user_file and user_file.token_count
else 0
)
total_token_count += token_count
chat_loaded_file = ChatLoadedFile(
file_id=file.file_id,
content=file.content,
file_type=file.file_type,
filename=file.filename,
content_text=None, # Images don't have text content
token_count=token_count,
)
project_image_files.append(chat_loaded_file)
else:
if DISABLE_VECTOR_DB:
# Without a vector DB we can't use project-as-filter search.
# Instead, build lightweight metadata so the LLM can call the
# FileReaderTool to inspect individual files on demand.
file_metadata_for_tool = _build_file_tool_metadata_for_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=project_tokens,
file_metadata_for_tool=file_metadata_for_tool,
)
project_as_filter = True
return ExtractedProjectFiles(
project_file_texts=project_file_texts,
project_image_files=project_image_files,
project_as_filter=project_as_filter,
total_token_count=total_token_count,
project_file_metadata=project_file_metadata,
project_uncapped_token_count=project_tokens,
)
return ExtractedContextFiles(
file_texts=file_texts,
image_files=image_files,
use_as_search_filter=False,
total_token_count=total_token_count,
file_metadata=file_metadata,
uncapped_token_count=aggregate_tokens,
)
APPROX_CHARS_PER_TOKEN = 4
def _build_file_tool_metadata_for_project(
project_id: int,
user_id: UUID | None,
db_session: Session,
) -> list[FileToolMetadata]:
"""Build lightweight FileToolMetadata for every file in a project.
Used when files are too large to fit in context and the vector DB is
disabled, so the LLM needs to know which files it can read via the
FileReaderTool.
"""
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return [
FileToolMetadata(
file_id=str(uf.id),
filename=uf.name,
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
)
for uf in project_user_files
]
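The approx_char_count above comes from a rough 4-characters-per-token heuristic rather than re-reading file contents. A quick sanity check of the arithmetic (the 2,500-token figure is illustrative):

APPROX_CHARS_PER_TOKEN = 4

token_count = 2_500  # hypothetical stored UserFile.token_count
approx_chars = token_count * APPROX_CHARS_PER_TOKEN
assert approx_chars == 10_000  # what FileToolMetadata advertises to the LLM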
def _build_file_tool_metadata_for_user_files(
user_files: list[UserFile],
) -> list[FileToolMetadata]:
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
]
def _get_project_search_availability(
def determine_search_params(
persona_id: int,
project_id: int | None,
persona_id: int | None,
loaded_project_files: bool,
project_has_files: bool,
forced_tool_id: int | None,
search_tool_id: int | None,
) -> ProjectSearchConfig:
"""Determine search tool availability based on project context.
extracted_context_files: ExtractedContextFiles,
) -> SearchParams:
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
Search is disabled when ALL of the following are true:
- User is in a project
- Using the default persona (not a custom agent)
- Project files are already loaded in context
A custom persona fully supersedes the project — project files are never
searchable and the search tool config is entirely controlled by the
persona. The project_id filter is only set for the default persona.
When search is disabled and the user tried to force the search tool,
that forcing is also disabled.
Returns AUTO (follow persona config) in all other cases.
For the default persona inside a project:
- Files overflow → ENABLED (vector DB scopes to these files)
- Files fit → DISABLED (content already in prompt)
- No files at all → DISABLED (nothing to search)
"""
# Not in a project, this should have no impact on search tool availability
if not project_id:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
# Custom persona in project - let persona config decide
# Even if there are no files in the project, it's still guided by the persona config.
if persona_id != DEFAULT_PERSONA_ID:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
search_project_id: int | None = None
search_persona_id: int | None = None
if extracted_context_files.use_as_search_filter:
if is_custom_persona:
search_persona_id = persona_id
else:
search_project_id = project_id
# If in a project with the default persona, and the files have already been loaded
# into the context or there are no files in the project, disable search as there is
# nothing to search for.
if loaded_project_files or not project_has_files:
user_forced_search = (
forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
)
return ProjectSearchConfig(
search_usage=SearchToolUsage.DISABLED,
disable_forced_tool=user_forced_search,
)
search_usage = SearchToolUsage.AUTO
if not is_custom_persona and project_id:
has_context_files = bool(extracted_context_files.uncapped_token_count)
files_loaded_in_context = bool(extracted_context_files.file_texts)
# Default persona in a project with files, but also the files have not been loaded into the context already.
return ProjectSearchConfig(
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
)
if extracted_context_files.use_as_search_filter:
search_usage = SearchToolUsage.ENABLED
elif files_loaded_in_context or not has_context_files:
search_usage = SearchToolUsage.DISABLED
return SearchParams(
search_project_id=search_project_id,
search_persona_id=search_persona_id,
search_usage=search_usage,
)
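The branching above reads most easily as a truth table. A minimal sketch of the same decision matrix for the default persona inside a project; the function and argument names here are illustrative, not the real types:

def resolve_search_usage(
    overflowed: bool,  # use_as_search_filter: files too large for the context
    loaded: bool,      # file_texts non-empty: content already in the prompt
    has_files: bool,   # uncapped_token_count is non-zero
) -> str:
    if overflowed:
        return "ENABLED"   # vector DB scopes search to these files
    if loaded or not has_files:
        return "DISABLED"  # already in the prompt, or nothing to search
    return "AUTO"          # defer to persona config

assert resolve_search_usage(True, False, True) == "ENABLED"
assert resolve_search_usage(False, True, True) == "DISABLED"
assert resolve_search_usage(False, False, False) == "DISABLED"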
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
user_memory_context=prompt_memory_context,
)
# Process projects; if all of the files fit in the context, RAG is not needed
extracted_project_files = _extract_project_file_texts_and_images(
# Determine which user files to use. A custom persona fully
# supersedes the project — project files are never loaded or
# searchable when a custom persona is in play. Only the default
# persona inside a project uses the project's files.
context_user_files = resolve_context_user_files(
persona=persona,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=llm.config.max_input_tokens,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
# When the vector DB is disabled, persona-attached user_files have no
# search pipeline path. Inject them as file_metadata_for_tool so the
# LLM can read them via the FileReaderTool.
if DISABLE_VECTOR_DB and persona.user_files:
persona_file_metadata = _build_file_tool_metadata_for_user_files(
persona.user_files
)
# Merge persona file metadata into the extracted project files
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
# Also grant access to persona-attached user files for FileReaderTool
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Build a mapping of tool_id to tool_name for history reconstruction
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
None,
)
# Determine if search should be disabled for this project context
forced_tool_id = new_msg_req.forced_tool_id
project_search_config = _get_project_search_availability(
project_id=chat_session.project_id,
persona_id=persona.id,
loaded_project_files=bool(extracted_project_files.project_file_texts),
project_has_files=bool(
extracted_project_files.project_uncapped_token_count
),
forced_tool_id=new_msg_req.forced_tool_id,
search_tool_id=search_tool_id,
)
if project_search_config.disable_forced_tool:
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
forced_tool_id = None
emitter = get_default_emitter()
# Also grant access to persona-attached user files
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
llm=llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id=(
chat_session.project_id
if extracted_project_files.project_as_filter
else None
),
project_id=search_params.search_project_id,
persona_id=search_params.search_persona_id,
bypass_acl=bypass_acl,
slack_context=slack_context,
enable_slack_search=_should_enable_slack_search(
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=project_search_config.search_usage,
search_usage_forcing_setting=search_params.search_usage,
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
project_image_files=extracted_project_files.project_image_files,
context_image_files=extracted_context_files.image_files,
additional_context=additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
@@ -879,46 +864,54 @@ def handle_stream_message_objects(
# (user has already responded to a clarification question)
skip_clarification = is_last_assistant_message_clarification(chat_history)
# NOTE: we _could_ pass in a zero-argument function since emitter and state_container
# are just passed in immediately anyway, but the abstraction is cleaner this way.
yield from run_chat_loop_with_state_containers(
run_deep_research_llm_loop,
lambda emitter, state_container: run_deep_research_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
),
llm_loop_completion_callback,
is_connected=check_is_connected,
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
)
else:
yield from run_chat_loop_with_state_containers(
run_llm_loop,
lambda emitter, state_container: run_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
),
llm_loop_completion_callback,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
except ValueError as e:

View File

@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
== "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a

View File

@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
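A Retry object only takes effect once mounted on the session via an HTTPAdapter. The mounting is outside this hunk, but the standard wiring looks roughly like this (a sketch, not the connector's exact code):

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# respect_retry_after_header defaults to True, so a 429 carrying a
# Retry-After header sleeps for the server-specified duration.
retry = Retry(total=10, backoff_factor=2, status_forcelist=[429, 500, 502, 503, 504])
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry))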
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
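The throttle above is the standard monotonic-clock spacing pattern; time.monotonic() is used because it cannot jump backwards on wall-clock adjustments. The same idea in isolation (an illustrative sketch, not the connector's code):

import time

class MinIntervalThrottle:
    """Ensures at least interval seconds elapse between calls."""

    def __init__(self, interval: float) -> None:
        self.interval = interval
        self._last = 0.0

    def wait(self) -> None:
        elapsed = time.monotonic() - self._last
        if elapsed < self.interval:
            time.sleep(self.interval - elapsed)
        self._last = time.monotonic()

throttle = MinIntervalThrottle(0.5)  # ~2 calls/sec, safely under Gong's 3/sec
for _ in range(3):
    throttle.wait()  # successive calls are spaced at least 0.5s apart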
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the ids.
# Retrying with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call ids after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}; "
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -23,7 +23,6 @@ from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import pkcs12
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
@@ -872,6 +871,56 @@ class SharepointConnector(
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
)
def _extract_tenant_domain_from_sites(self) -> str | None:
"""Extract the tenant domain from configured site URLs.
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
tenant domain is the first label of the hostname.
"""
for site_url in self.sites:
try:
hostname = urlsplit(site_url.strip()).hostname
except ValueError:
continue
if not hostname:
continue
tenant = hostname.split(".")[0]
if tenant:
return tenant
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
return None
def _resolve_tenant_domain_from_root_site(self) -> str:
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
Sites.Read.All (a permission the connector already needs)."""
root_site = self.graph_client.sites.root.get().execute_query()
hostname = root_site.site_collection.hostname
if not hostname:
raise ConnectorValidationError(
"Could not determine tenant domain from root site"
)
tenant_domain = hostname.split(".")[0]
logger.info(
"Resolved tenant domain '%s' from root site hostname '%s'",
tenant_domain,
hostname,
)
return tenant_domain
def _resolve_tenant_domain(self) -> str:
"""Determine the tenant domain, preferring site URLs over a Graph API
call to avoid needing extra permissions."""
from_sites = self._extract_tenant_domain_from_sites()
if from_sites:
logger.info(
"Resolved tenant domain '%s' from site URLs",
from_sites,
)
return from_sites
logger.info("No site URLs available; resolving tenant domain from root site")
return self._resolve_tenant_domain_from_root_site()
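Concretely, the first hostname label is the tenant for standard SharePoint Online URLs. A minimal sketch of the extraction, with a made-up tenant name:

from urllib.parse import urlsplit

def tenant_from_site_url(site_url: str) -> str | None:
    # https://{tenant}.sharepoint.com/sites/... -> "{tenant}"
    hostname = urlsplit(site_url.strip()).hostname
    return hostname.split(".")[0] if hostname else None

assert tenant_from_site_url("https://contoso.sharepoint.com/sites/eng") == "contoso"
assert tenant_from_site_url("not a url") is None  # no hostname to parse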
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
@@ -1589,6 +1638,11 @@ class SharepointConnector(
sp_private_key = credentials.get("sp_private_key")
sp_certificate_password = credentials.get("sp_certificate_password")
if not sp_client_id:
raise ConnectorValidationError("Client ID is required")
if not sp_directory_id:
raise ConnectorValidationError("Directory (tenant) ID is required")
authority_url = f"{self.authority_host}/{sp_directory_id}"
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
@@ -1641,21 +1695,7 @@ class SharepointConnector(
_acquire_token_for_graph, environment=self._azure_environment
)
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
org = self.graph_client.organization.get().execute_query()
if not org or len(org) == 0:
raise ConnectorValidationError("No organization found")
tenant_info: Organization = org[
0
] # Access first item directly from collection
if not tenant_info.verified_domains:
raise ConnectorValidationError("No verified domains found for tenant")
sp_tenant_domain = tenant_info.verified_domains[0].name
if not sp_tenant_domain:
raise ConnectorValidationError("No verified domains found for tenant")
# remove the .onmicrosoft.com part
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
self.sp_tenant_domain = self._resolve_tenant_domain()
return None
def _get_drive_names_for_site(self, site_url: str) -> list[str]:

View File

@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine.iam_auth import ssl_context
from onyx.db.engine.sql_engine import ASYNC_DB_API
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import is_valid_schema_name
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
connect_args["ssl"] = create_ssl_context_if_iam()
engine_kwargs = {
"connect_args": connect_args,
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
cparams["ssl"] = create_ssl_context_if_iam()
return _ASYNC_ENGINE

View File

@@ -1,3 +1,4 @@
import functools
import os
import ssl
from typing import Any
@@ -48,11 +49,9 @@ def provide_iam_token(
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@functools.cache
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()
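Replacing the module-level ssl_context with a functools.cache'd factory defers certificate loading from import time to first use while still building the context exactly once. A simplified sketch of the caching behavior (the real function additionally returns None when IAM auth is disabled):

import functools
import ssl

@functools.cache
def get_ssl_context() -> ssl.SSLContext:
    # Runs on the first call rather than at import; the result is memoized.
    return ssl.create_default_context()

assert get_ssl_context() is get_ssl_context()  # one shared context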

View File

@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
api_key=api_key,
api_base=None,
api_version=None,
default_model_name=model_name,
deployment_name=None,
is_public=True,
)

View File

@@ -213,11 +213,29 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
existing_llm_provider: LLMProviderModel | None = None
if llm_provider_upsert_request.id:
existing_llm_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if not existing_llm_provider:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} not found"
)
if not existing_llm_provider:
if existing_llm_provider.name != llm_provider_upsert_request.name:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
)
else:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider:
raise ValueError(
f"LLM provider with name '{llm_provider_upsert_request.name}'"
" already exists"
)
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
db_session.add(existing_llm_provider)
@@ -238,11 +256,7 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -306,15 +320,6 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -488,6 +493,22 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -604,22 +625,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
_update_default_model(
db_session,
provider_id,
provider.default_model_name,
model_name,
LLMModelFlowType.CHAT,
)
@@ -805,12 +817,6 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes

View File

@@ -2822,13 +2822,17 @@ class LLMProvider(Base):
custom_config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
# Deprecated: use LLMModelFlow with CHAT flow type instead
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
# Deprecated: use LLMModelFlow with VISION flow type instead
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@@ -2879,6 +2883,7 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Deprecated: use LLMModelFlow with VISION flow type instead
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Human-readable display name for the model.

View File

@@ -256,9 +256,6 @@ def create_update_persona(
try:
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
# Curators can edit default personas, but not make them
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
pass
@@ -335,6 +332,7 @@ def update_persona_shared(
db_session: Session,
group_ids: list[int] | None = None,
is_public: bool | None = None,
label_ids: list[int] | None = None,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
@@ -344,9 +342,7 @@ def update_persona_shared(
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise HTTPException(
status_code=403, detail="You don't have permission to modify this persona"
)
raise PermissionError("You don't have permission to modify this persona")
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
@@ -360,6 +356,15 @@ def update_persona_shared(
group_ids=group_ids,
)
if label_ids is not None:
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
persona.labels.clear()
persona.labels = labels
db_session.commit()
@@ -965,6 +970,8 @@ def upsert_persona(
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
# Fetch and attach hierarchy_nodes by IDs
hierarchy_nodes = None
@@ -1161,9 +1168,6 @@ def update_persona_is_default(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if not persona.is_public:
persona.is_public = True
persona.is_default_persona = is_default
db_session.commit()

View File

@@ -6,6 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile
@@ -57,12 +58,19 @@ def fetch_user_project_ids_for_user_files(
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch user project ids for specified user files"""
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
results = db_session.execute(stmt).scalars().all()
return {
str(user_file.id): [project.id for project in user_file.projects]
for user_file in results
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
)
rows = db_session.execute(stmt).all()
user_file_id_to_project_ids: dict[str, list[int]] = {
user_file_id: [] for user_file_id in user_file_ids
}
for user_file_id, project_id in rows:
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
return user_file_id_to_project_ids
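The rewrite selects from the Project__UserFile association table directly instead of loading full UserFile rows and touching each .projects relationship, which avoids a lazy-load query per file. The same shape as a generic sketch (the assoc_model argument and its columns are hypothetical):

from sqlalchemy import select
from sqlalchemy.orm import Session

def child_to_parent_ids(
    db_session: Session, assoc_model, child_ids: list
) -> dict[str, list[int]]:
    # Pre-seed every requested id so children with no parents still map to []
    mapping: dict[str, list[int]] = {str(cid): [] for cid in child_ids}
    stmt = select(assoc_model.child_id, assoc_model.parent_id).where(
        assoc_model.child_id.in_(child_ids)
    )
    for child_id, parent_id in db_session.execute(stmt):
        mapping[str(child_id)].append(parent_id)
    return mapping

Pre-seeding keeps the old contract: files with no projects still get an empty list.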
def fetch_persona_ids_for_user_files(

View File

@@ -139,7 +139,7 @@ def generate_final_report(
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
project_files=None,
context_files=None,
available_tokens=llm.config.max_input_tokens,
all_injected_file_metadata=all_injected_file_metadata,
)
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history + [reminder_message],
reminder_message=None,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
all_injected_file_metadata=all_injected_file_metadata,
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=first_cycle_reminder_message,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from shared_configs.configs import MULTI_TENANT
@@ -49,8 +50,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
@@ -118,8 +122,11 @@ def get_all_document_indices(
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,

View File

@@ -1,5 +1,7 @@
import logging
import time
from contextlib import AbstractContextManager
from contextlib import nullcontext
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
return new_body
class OpenSearchClient:
"""Client for interacting with OpenSearch.
class OpenSearchClient(AbstractContextManager):
"""Client for interacting with OpenSearch for cluster-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
@@ -107,9 +113,8 @@ class OpenSearchClient:
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
self._index_name = index_name
logger.debug(
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
)
self._client = OpenSearch(
hosts=[{"host": host, "port": port}],
@@ -125,6 +130,142 @@ class OpenSearchClient:
# your request body that is less than this value.
timeout=timeout,
)
def __exit__(self, *_: Any) -> None:
self.close()
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
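With OpenSearchClient now an AbstractContextManager (whose __exit__ delegates to close(), backed by a best-effort __del__), cluster-level callers can scope the connection with a with-block. Illustrative usage:

with OpenSearchClient() as client:
    if not client.ping():
        raise RuntimeError("OpenSearch is not reachable")
# client.close() has run here, even if the body raised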
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
class OpenSearchIndexClient(OpenSearchClient):
"""Client for interacting with OpenSearch for index-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
index_name: The name of the index to interact with.
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
super().__init__(
host=host,
port=port,
auth=auth,
use_ssl=use_ssl,
verify_certs=verify_certs,
ssl_show_warn=ssl_show_warn,
timeout=timeout,
)
self._index_name = index_name
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -192,6 +333,38 @@ class OpenSearchClient:
"""
return self._client.indices.exists(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_mapping(self, mappings: dict[str, Any]) -> None:
"""Updates the index mapping in an idempotent manner.
- Existing fields with the same definition: No-op (succeeds silently).
- New fields: Added to the index.
- Existing fields with different types: Raises exception (requires
reindex).
See the OpenSearch documentation for more information:
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
Args:
mappings: The complete mapping definition to apply. This will be
merged with existing mappings in the index.
Raises:
Exception: There was an error updating the mappings, such as
attempting to change the type of an existing field.
"""
logger.debug(
f"Putting mappings for index {self._index_name} with mappings {mappings}."
)
response = self._client.indices.put_mapping(
index=self._index_name, body=mappings
)
if not response.get("acknowledged", False):
raise RuntimeError(
f"Failed to put the mapping update for index {self._index_name}."
)
logger.debug(f"Successfully put mappings for index {self._index_name}.")
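Because the PUT-mapping API merges rather than replaces, put_mapping is safe to call on every startup. A hypothetical example of the three cases the docstring lists (assumes a reachable cluster):

client = OpenSearchIndexClient(index_name="docs")  # hypothetical index name
new_field = {"properties": {"tenant_id": {"type": "keyword"}}}
client.put_mapping(new_field)  # new field: added to the index
client.put_mapping(new_field)  # identical definition: silent no-op
# Re-declaring tenant_id with a different type (e.g. "integer") raises,
# since changing an existing field's type requires a reindex.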
@log_function_time(print_only=True, debug_only=True, include_args=True)
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
"""Validates the index.
@@ -610,43 +783,6 @@ class OpenSearchClient:
)
return DocumentChunk.model_validate(document_chunk_source)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
@@ -807,48 +943,6 @@ class OpenSearchClient:
"""
self._client.indices.refresh(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
TODO(andrei): Can we have some way to auto close when the client no
longer has any references?
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
def _get_hits_and_profile_from_search_result(
self, result: dict[str, Any]
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
Returns:
True if OpenSearch is ready, False otherwise.
"""
made_client = False
try:
if client is None:
# NOTE: index_name does not matter because we are only using this object
# to ping.
# TODO(andrei): Make this better.
client = OpenSearchClient(index_name="")
made_client = True
with nullcontext(client) if client else OpenSearchClient() as client:
time_start = time.monotonic()
while True:
if client.ping():
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
)
time.sleep(wait_interval_s)
finally:
if made_client:
assert client is not None
client.close()
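The nullcontext wrapper is a borrow-or-own pattern: a caller-supplied client passes through untouched, while a client created here is closed when the block exits, replacing the manual made_client bookkeeping. The pattern in isolation (sketch):

from contextlib import nullcontext

def probe(client: OpenSearchClient | None = None) -> bool:
    # Borrowed clients survive the block; an owned one is closed on exit.
    with nullcontext(client) if client else OpenSearchClient() as c:
        return c.ping()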

View File

@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
return list(access_control_list)
def set_cluster_state(client: OpenSearchClient) -> None:
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
logger.error(
"Failed to put cluster settings. If the settings have never been set before, "
"this may cause unexpected index creation when indexing documents into an "
"index that does not exist, or may cause expected logs to not appear. If this "
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def __init__(
self,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
index_name=index_name,
secondary_index_name=secondary_index_name,
)
if multitenant:
raise ValueError(
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
)
if multitenant != MULTI_TENANT:
raise ValueError(
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
)
tenant_id = get_current_tenant_id()
self._real_index = OpenSearchDocumentIndex(
index_name=index_name,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
@staticmethod
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
# TODO(andrei): Implement.
raise NotImplementedError(
"Multitenant index registration is not yet implemented for OpenSearch."
"Bug: Multitenant index registration is not supported for OpenSearch."
)
def ensure_indices_exist(
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
for an OpenSearch search engine instance. It handles the complete lifecycle
of document chunks within a specific OpenSearch index/schema.
Although not yet used in this way in the codebase, each kind of embedding
used should correspond to a different instance of this class, and therefore
a different index in OpenSearch.
Each kind of embedding used should correspond to a different instance of
this class, and therefore a different index in OpenSearch.
In a multitenant environment with VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
set, this verifies and creates the index if necessary on initialization,
because no logic runs on cluster restart that scans through all search
settings across all tenants and creates the relevant indices.
Args:
tenant_state: The tenant state of the caller.
index_name: The name of the index to interact with.
embedding_dim: The dimensionality of the embeddings used for the index.
embedding_precision: The precision of the embeddings used for the index.
"""
def __init__(
self,
index_name: str,
tenant_state: TenantState,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> None:
self._index_name: str = index_name
self._tenant_state: TenantState = tenant_state
self._os_client = OpenSearchClient(index_name=self._index_name)
self._client = OpenSearchIndexClient(index_name=self._index_name)
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
self.verify_and_create_index_if_necessary(
embedding_dim=embedding_dim, embedding_precision=embedding_precision
)
def verify_and_create_index_if_necessary(
self,
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
) -> None:
"""Verifies and creates the index if necessary.
Also puts the desired cluster settings.
Also puts the desired cluster settings if not in a multitenant
environment.
Also puts the desired search pipeline state, creating the pipelines if
they do not exist and updating them otherwise.
Also puts the desired search pipeline state if not in a multitenant
environment, creating the pipelines if they do not exist and updating
them otherwise.
In a multitenant environment, the above steps happen explicitly on
setup.
Args:
embedding_dim: Vector dimensionality for the vector similarity part
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
search pipelines.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
f"with embedding dimension {embedding_dim}."
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
f"necessary, with embedding dimension {embedding_dim}."
)
if not self._tenant_state.multitenant:
set_cluster_state(self._client)
expected_mappings = DocumentSchema.get_document_schema(
embedding_dim, self._tenant_state.multitenant
)
if not self._os_client.put_cluster_settings(
settings=OPENSEARCH_CLUSTER_SETTINGS
):
logger.error(
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
"these settings have likely already been set. Not taking any further action..."
)
if not self._os_client.index_exists():
if not self._client.index_exists():
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
self._os_client.create_index(
self._client.create_index(
mappings=expected_mappings,
settings=index_settings,
)
if not self._os_client.validate_index(
expected_mappings=expected_mappings,
):
raise RuntimeError(
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
)
self._os_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
self._os_client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
else:
# Ensure schema is up to date by applying the current mappings.
try:
self._client.put_mapping(expected_mappings)
except Exception as e:
logger.error(
f"Failed to update mappings for index {self._index_name}. This likely means a "
f"field type was changed which requires reindexing. Error: {e}"
)
raise
def index(
self,
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
tenant_state=self._tenant_state,
)
return self._os_client.delete_by_query(query_body)
return self._client.delete_by_query(query_body)
def update(
self,
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
document_id=doc_id,
chunk_index=chunk_index,
)
self._os_client.update_document(
self._client.update_document(
document_chunk_id=document_chunk_id,
properties_to_update=properties_to_update,
)
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
)
search_hits = self._os_client.search(
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
# Do not raise if the document already exists, just update. This is
# because the document may already have been indexed during the
# OpenSearch transition period.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
)

View File

@@ -405,6 +405,7 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID] | None = None
group_ids: list[int] | None = None
is_public: bool | None = None
label_ids: list[int] | None = None
# We notify each user when a user is shared with them
@@ -415,14 +416,22 @@ def share_persona(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
)
try:
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
label_ids=persona_share_request.label_ids,
)
except PermissionError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=400, detail=str(e))
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)

View File

@@ -97,7 +97,6 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -233,12 +237,9 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -268,7 +269,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -303,7 +304,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -328,7 +329,15 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
@admin_router.put("/provider")
@@ -344,18 +353,44 @@ def put_llm_provider(
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_existing_llm_provider(
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
# Check name constraints
# TODO: Once the port from name to id is complete, a unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -393,22 +428,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -438,8 +457,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -469,28 +488,29 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
@admin_router.post("/default")
def set_provider_as_default(
provider_id: int,
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
@admin_router.post("/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
default_model: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_model.provider_id,
vision_model=default_model.model_name,
db_session=db_session,
)
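Hedged usage sketch for the reworked default-model endpoints: both now take a DefaultModel body instead of a provider_id path parameter. The /admin/llm prefix and host are assumptions:

import requests

payload = {"provider_id": 1, "model_name": "gpt-4o-mini"}  # DefaultModel fields
requests.post("http://localhost:8080/admin/llm/default", json=payload)
requests.post("http://localhost:8080/admin/llm/default-vision", json=payload)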
@@ -516,7 +536,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -545,7 +565,13 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return vision_provider_response
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
"""Endpoints for all"""
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -592,7 +618,15 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
def get_valid_model_names_for_persona(
@@ -635,7 +669,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -682,7 +716,51 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
return llm_provider_list
# Get the default model and vision model for the persona
# TODO: Port personas over to use ID
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text = DefaultModel.from_model_config(default_text_model)
default_vision = DefaultModel.from_model_config(default_vision_model)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider and can_user_access_llm_provider(
provider, user_group_ids, persona, is_admin=is_admin
):
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
@admin_router.get("/provider-contextual-cost")

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
# TODO: Clear this up in the API refactor
# There is still logic that requires sending each provider's default model name
# There is no logic that requires sending a provider's default vision model name
# We only send it for the one that is actually the default
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
"""Find the default conversation model name for a provider.
Returns the model name if found, otherwise returns empty string.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
return model_config.name
return ""
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
"""Find the default vision model name for a provider.
Returns the model name if found, otherwise returns None.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
return model_config.name
return None
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
id: int | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# when False, try to reuse the existing API key / custom config
api_key_changed: bool
custom_config_changed: bool
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -99,22 +72,12 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = default_model_name or llm_provider_model.default_model_name
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -128,18 +91,17 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -155,8 +117,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -178,14 +138,6 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = default_model_name or llm_provider_model.default_model_name
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -198,10 +150,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -421,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)
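A self-contained sketch of the envelope shape this generic model produces. The stand-in types below are hypothetical (the real T is constrained to the three provider view classes above); the sketch exists only to show the serialized layout:

from typing import Generic, TypeVar

from pydantic import BaseModel

class ProviderStub(BaseModel):  # hypothetical stand-in for T
    id: int
    name: str

class DefaultModelSketch(BaseModel):
    provider_id: int
    model_name: str

P = TypeVar("P", bound=BaseModel)

class ProviderResponseSketch(BaseModel, Generic[P]):
    providers: list[P]
    default_text: DefaultModelSketch | None = None
    default_vision: DefaultModelSketch | None = None

resp = ProviderResponseSketch[ProviderStub](
    providers=[ProviderStub(id=1, name="OpenAI")],
    default_text=DefaultModelSketch(provider_id=1, model_name="gpt-4o-mini"),
)
print(resp.model_dump())
# {'providers': [{'id': 1, 'name': 'OpenAI'}],
#  'default_text': {'provider_id': 1, 'model_name': 'gpt-4o-mini'},
#  'default_vision': None}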

View File

@@ -0,0 +1,27 @@
"""Per-tenant request counter metric.
Increments a counter on every request, labelled by tenant, so Grafana can
answer "which tenant is generating the most traffic?"
"""
from prometheus_client import Counter
from prometheus_fastapi_instrumentator.metrics import Info
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_requests_by_tenant = Counter(
"onyx_api_requests_by_tenant_total",
"Total API requests by tenant",
["tenant_id", "method", "handler", "status"],
)
def per_tenant_request_callback(info: Info) -> None:
"""Increment per-tenant request counter for every request."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
_requests_by_tenant.labels(
tenant_id=tenant_id,
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()
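A hedged sketch of exercising the callback above outside FastAPI. The SimpleNamespace is a stand-in for Info that exposes only the attributes the callback reads; the tenant id is hypothetical:

from types import SimpleNamespace

from prometheus_client import generate_latest

from onyx.server.metrics.per_tenant import per_tenant_request_callback
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc")  # hypothetical tenant
fake_info = SimpleNamespace(
    method="GET", modified_handler="/api/chat", modified_status="2xx"
)
per_tenant_request_callback(fake_info)  # type: ignore[arg-type]
# The rendered metrics now include onyx_api_requests_by_tenant_total{...}
print(generate_latest().decode())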

View File

@@ -32,6 +32,7 @@ from sqlalchemy.pool import QueuePool
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -72,7 +73,7 @@ _checkout_timeout_total = Counter(
_connections_held = Gauge(
"onyx_db_connections_held_by_endpoint",
"Number of DB connections currently held, by endpoint and engine",
["handler", "engine"],
["handler", "engine", "tenant_id"],
)
_hold_seconds = Histogram(
@@ -163,10 +164,14 @@ def _register_pool_events(engine: Engine, label: str) -> None:
conn_proxy: PoolProxiedConnection, # noqa: ARG001
) -> None:
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
conn_record.info["_metrics_endpoint"] = handler
conn_record.info["_metrics_tenant_id"] = tenant_id
conn_record.info["_metrics_checkout_time"] = time.monotonic()
_checkout_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).inc()
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).inc()
@event.listens_for(engine, "checkin")
def on_checkin(
@@ -174,9 +179,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
conn_record: ConnectionPoolEntry,
) -> None:
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
_checkin_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).dec()
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).dec()
if start is not None:
_hold_seconds.labels(handler=handler, engine=label).observe(
time.monotonic() - start
@@ -199,9 +207,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
# Defensively clean up the held-connections gauge in case checkin
# doesn't fire after invalidation (e.g. hard pool shutdown).
handler = conn_record.info.pop("_metrics_endpoint", None)
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
if handler:
_connections_held.labels(handler=handler, engine=label).dec()
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).dec()
if start is not None:
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
time.monotonic() - start
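The bookkeeping above follows a general SQLAlchemy pattern: stash the label values on the connection record's info dict at checkout, then pop them at checkin (or invalidation) so the gauge is decremented with exactly the labels it was incremented with. A minimal standalone sketch of that pattern, with a hypothetical gauge and an in-memory engine:

import time

from prometheus_client import Gauge
from sqlalchemy import create_engine, event

_held = Gauge("demo_db_connections_held", "Connections held", ["tenant_id"])
engine = create_engine("sqlite://")  # stand-in engine for the sketch

@event.listens_for(engine, "checkout")
def _on_checkout(dbapi_conn, conn_record, conn_proxy):
    conn_record.info["_tenant"] = "tenant_abc"  # would come from a contextvar
    conn_record.info["_t0"] = time.monotonic()
    _held.labels(tenant_id=conn_record.info["_tenant"]).inc()

@event.listens_for(engine, "checkin")
def _on_checkin(dbapi_conn, conn_record):
    tenant = conn_record.info.pop("_tenant", "unknown")
    conn_record.info.pop("_t0", None)
    _held.labels(tenant_id=tenant).dec()  # same labels as the inc()

with engine.connect():
    pass  # checkout fires on connect, checkin when the connection closes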

View File

@@ -11,9 +11,11 @@ SQLAlchemy connection pool metrics are registered separately via
"""
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
from sqlalchemy.exc import TimeoutError as SATimeoutError
from starlette.applications import Starlette
from onyx.server.metrics.per_tenant import per_tenant_request_callback
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
from onyx.server.metrics.slow_requests import slow_request_callback
@@ -59,6 +61,15 @@ def setup_prometheus_metrics(app: Starlette) -> None:
excluded_handlers=_EXCLUDED_HANDLERS,
)
# Explicitly create the default metrics (http_requests_total,
# http_request_duration_seconds, etc.) and add them first. The library
# skips creating defaults when ANY custom instrumentations are registered
# via .add(), so we must include them ourselves.
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
if default_callback:
instrumentator.add(default_callback)
instrumentator.add(slow_request_callback)
instrumentator.add(per_tenant_request_callback)
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)

View File

@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
class Notification(BaseModel):
@@ -82,6 +83,10 @@ class Settings(BaseModel):
# Default Assistant settings
disable_default_assistant: bool | None = False
# Seat usage - populated by license enforcement when seat limit is exceeded
seat_count: int | None = None
used_seats: int | None = None
# OpenSearch migration
opensearch_indexing_enabled: bool = False

View File

@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
@@ -24,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -32,6 +34,9 @@ from onyx.db.search_settings import update_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
@@ -250,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
provider_name = "DevEnvPresetOpenAI"
existing = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
id=existing.id if existing else None,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -269,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:
@@ -311,7 +322,14 @@ def setup_multitenant_onyx() -> None:
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
return
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
opensearch_client = OpenSearchClient()
if not wait_for_opensearch_with_timeout(client=opensearch_client):
raise RuntimeError("Failed to connect to OpenSearch.")
set_cluster_state(opensearch_client)
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
# NOTE: Pretty sure this code is never hit in any production environment.
if not MANAGED_VESPA:
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
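wait_for_opensearch_with_timeout is not shown in this diff; conceptually it is a bounded readiness poll. A hypothetical sketch of that shape (the name, timeout, and interval here are assumptions):

import time
from collections.abc import Callable

def wait_until_ready(
    ping: Callable[[], bool], timeout_s: float = 60.0, interval_s: float = 2.0
) -> bool:
    # Poll `ping` until it reports healthy or the deadline passes.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if ping():
            return True
        time.sleep(interval_s)
    return False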

View File

@@ -120,7 +120,7 @@ def generate_intermediate_report(
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
project_files=None,
context_files=None,
available_tokens=llm.config.max_input_tokens,
)
@@ -325,7 +325,7 @@ def run_research_agent_call(
custom_agent_prompt=None,
simple_chat_history=msg_history,
reminder_message=reminder_message,
project_files=None,
context_files=None,
available_tokens=llm.config.max_input_tokens,
)

View File

@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
# via
# braintrust
# fastmcp
fastapi==0.128.0
fastapi==0.133.1
# via
# fastapi-limiter
# fastapi-users
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.6.2
pypdf==6.7.3
# via
# onyx
# unstructured-client
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
# via dataclasses-json
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings
@@ -1216,7 +1217,7 @@ websockets==15.0.1
# via
# fastmcp
# google-genai
werkzeug==3.1.5
werkzeug==3.1.6
# via sendgrid
wrapt==1.17.3
# via

View File

@@ -125,7 +125,7 @@ executing==2.2.1
# via stack-data
faker==40.1.2
# via onyx
fastapi==0.128.0
fastapi==0.133.1
# via
# onyx
# onyx-devtools
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -90,7 +90,7 @@ docstring-parser==0.17.0
# via google-cloud-aiplatform
durationpy==0.10
# via kubernetes
fastapi==0.128.0
fastapi==0.133.1
# via onyx
fastavro==1.12.1
# via cohere
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -108,7 +108,7 @@ durationpy==0.10
# via kubernetes
einops==0.8.1
# via onyx
fastapi==0.128.0
fastapi==0.133.1
# via
# onyx
# sentry-sdk
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -12,6 +12,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.enums import HierarchyNodeType
from tests.daily.connectors.utils import load_all_from_connector
@@ -521,3 +522,46 @@ def test_sharepoint_connector_hierarchy_nodes(
f"Document {doc.semantic_identifier} should have "
"parent_hierarchy_raw_node_id set"
)
@pytest.fixture
def sharepoint_cert_credentials() -> dict[str, str]:
return {
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
"sp_certificate_password": os.environ[
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
],
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
}
def test_resolve_tenant_domain_from_site_urls(
sharepoint_cert_credentials: dict[str, str],
) -> None:
"""Verify that certificate auth resolves the tenant domain from site URLs
without calling the /organization endpoint."""
site_url = os.environ["SHAREPOINT_SITE"]
connector = SharepointConnector(sites=[site_url])
connector.load_credentials(sharepoint_cert_credentials)
assert connector.sp_tenant_domain is not None
assert len(connector.sp_tenant_domain) > 0
# The tenant domain should match the first label of the site URL hostname
from urllib.parse import urlsplit
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
assert connector.sp_tenant_domain == expected
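# Worked example of the derivation above (the tenant name is hypothetical):
#   urlsplit("https://contoso.sharepoint.com/sites/eng").hostname
#     -> "contoso.sharepoint.com"; .split(".")[0] -> "contoso"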
def test_resolve_tenant_domain_from_root_site(
sharepoint_cert_credentials: dict[str, str],
) -> None:
"""Verify that certificate auth resolves the tenant domain via the root
site endpoint when no site URLs are configured."""
connector = SharepointConnector(sites=[])
connector.load_credentials(sharepoint_cert_credentials)
assert connector.sp_tenant_domain is not None
assert len(connector.sp_tenant_domain) > 0

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()

View File

@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -0,0 +1,544 @@
"""
External dependency unit tests for persona file sync.
Validates that:
1. The check_for_user_file_project_sync beat task picks up UserFiles with
needs_persona_sync=True (not just needs_project_sync).
2. The process_single_user_file_project_sync worker task reads persona
associations from the DB, passes persona_ids to the document index via
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
3. upsert_persona correctly marks affected UserFiles with
needs_persona_sync=True when file associations change.
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
since we only need to verify the arguments passed to update_single.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_for_user_file_project_sync,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file_project_sync,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
user_file_project_sync_lock_key,
)
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import upsert_persona
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _create_completed_user_file(
db_session: Session,
user: User,
needs_persona_sync: bool = False,
needs_project_sync: bool = False,
) -> UserFile:
"""Insert a UserFile in COMPLETED status."""
uf = UserFile(
id=uuid4(),
user_id=user.id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.COMPLETED,
needs_persona_sync=needs_persona_sync,
needs_project_sync=needs_project_sync,
chunk_count=5,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
def _create_test_persona(
db_session: Session,
user: User,
user_files: list[UserFile] | None = None,
) -> Persona:
"""Create a minimal Persona via direct model insert."""
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a test assistant",
task_prompt="Answer the question",
tools=[],
document_sets=[],
users=[user],
groups=[],
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,
deleted=False,
user_files=user_files or [],
user_id=user.id,
)
db_session.add(persona)
db_session.commit()
db_session.refresh(persona)
return persona
def _link_file_to_persona(
db_session: Session, persona: Persona, user_file: UserFile
) -> None:
"""Create the join table row between a persona and a user file."""
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
db_session.add(link)
db_session.commit()
_PATCH_QUEUE_DEPTH = (
"onyx.background.celery.tasks.user_file_processing.tasks"
".get_user_file_project_sync_queue_depth"
)
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on a bound Celery task."""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
):
yield
# ---------------------------------------------------------------------------
# Test: check_for_user_file_project_sync picks up persona sync
# ---------------------------------------------------------------------------
class TestCheckSweepIncludesPersonaSync:
"""The beat task must pick up files needing persona sync, not just project sync."""
def test_persona_sync_flag_enqueues_task(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
user = create_test_user(db_session, "persona_sweep")
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
mock_app = MagicMock()
with _patch_task_app(check_for_user_file_project_sync, mock_app):
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
enqueued_ids = {
call.kwargs["kwargs"]["user_file_id"]
for call in mock_app.send_task.call_args_list
}
assert str(uf.id) in enqueued_ids
def test_neither_flag_does_not_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file with both flags False is not enqueued."""
user = create_test_user(db_session, "no_sync")
uf = _create_completed_user_file(db_session, user)
mock_app = MagicMock()
with _patch_task_app(check_for_user_file_project_sync, mock_app):
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
enqueued_ids = {
call.kwargs["kwargs"]["user_file_id"]
for call in mock_app.send_task.call_args_list
}
assert str(uf.id) not in enqueued_ids
def test_both_flags_enqueues_once(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file with BOTH flags True is enqueued exactly once."""
user = create_test_user(db_session, "both_flags")
uf = _create_completed_user_file(
db_session, user, needs_persona_sync=True, needs_project_sync=True
)
mock_app = MagicMock()
with _patch_task_app(check_for_user_file_project_sync, mock_app):
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
matching_calls = [
call
for call in mock_app.send_task.call_args_list
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
]
assert len(matching_calls) == 1
# ---------------------------------------------------------------------------
# Test: process_single_user_file_project_sync passes persona_ids to index
# ---------------------------------------------------------------------------
_PATCH_GET_SETTINGS = (
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
)
_PATCH_GET_INDICES = (
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
)
_PATCH_HTTPX_INIT = (
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
)
_PATCH_DISABLE_VDB = (
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
)
class TestSyncTaskWritesPersonaIds:
"""The sync task reads persona associations and sends them to the index."""
def test_passes_persona_ids_to_update_single(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After linking a file to a persona, sync sends the persona ID."""
user = create_test_user(db_session, "sync_persona")
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
persona = _create_test_persona(db_session, user)
_link_file_to_persona(db_session, persona, uf)
mock_doc_index = MagicMock()
mock_search_settings = MagicMock()
mock_search_settings.primary = MagicMock()
mock_search_settings.secondary = None
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
lock_key = user_file_project_sync_lock_key(str(uf.id))
redis_client.delete(lock_key)
with (
patch(_PATCH_DISABLE_VDB, False),
patch(_PATCH_HTTPX_INIT),
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
):
process_single_user_file_project_sync.run(
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
)
mock_doc_index.update_single.assert_called_once()
call_args = mock_doc_index.update_single.call_args
user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
assert user_fields.personas is not None
assert persona.id in user_fields.personas
assert call_args.args[0] == str(uf.id)
def test_clears_persona_sync_flag(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a successful sync the needs_persona_sync flag is cleared."""
user = create_test_user(db_session, "sync_clear")
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
lock_key = user_file_project_sync_lock_key(str(uf.id))
redis_client.delete(lock_key)
with patch(_PATCH_DISABLE_VDB, True):
process_single_user_file_project_sync.run(
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
)
db_session.refresh(uf)
assert uf.needs_persona_sync is False
def test_passes_both_project_and_persona_ids(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file linked to both a project and a persona gets both IDs."""
from onyx.db.models import Project__UserFile
from onyx.db.models import UserProject
user = create_test_user(db_session, "sync_both")
uf = _create_completed_user_file(
db_session, user, needs_persona_sync=True, needs_project_sync=True
)
persona = _create_test_persona(db_session, user)
_link_file_to_persona(db_session, persona, uf)
project = UserProject(user_id=user.id, name="test-project", instructions="")
db_session.add(project)
db_session.commit()
db_session.refresh(project)
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
db_session.add(link)
db_session.commit()
mock_doc_index = MagicMock()
mock_search_settings = MagicMock()
mock_search_settings.primary = MagicMock()
mock_search_settings.secondary = None
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
lock_key = user_file_project_sync_lock_key(str(uf.id))
redis_client.delete(lock_key)
with (
patch(_PATCH_DISABLE_VDB, False),
patch(_PATCH_HTTPX_INIT),
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
):
process_single_user_file_project_sync.run(
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
)
call_kwargs = mock_doc_index.update_single.call_args.kwargs
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
assert user_fields.personas is not None
assert user_fields.user_projects is not None
assert persona.id in user_fields.personas
assert project.id in user_fields.user_projects
# Both flags should be cleared
db_session.refresh(uf)
assert uf.needs_persona_sync is False
assert uf.needs_project_sync is False
def test_deleted_persona_excluded_from_ids(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
user = create_test_user(db_session, "sync_deleted")
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
persona = _create_test_persona(db_session, user)
_link_file_to_persona(db_session, persona, uf)
persona.deleted = True
db_session.commit()
mock_doc_index = MagicMock()
mock_search_settings = MagicMock()
mock_search_settings.primary = MagicMock()
mock_search_settings.secondary = None
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
lock_key = user_file_project_sync_lock_key(str(uf.id))
redis_client.delete(lock_key)
with (
patch(_PATCH_DISABLE_VDB, False),
patch(_PATCH_HTTPX_INIT),
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
):
process_single_user_file_project_sync.run(
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
)
call_kwargs = mock_doc_index.update_single.call_args.kwargs
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
assert user_fields.personas is not None
assert persona.id not in user_fields.personas
# ---------------------------------------------------------------------------
# Test: upsert_persona marks files for persona sync
# ---------------------------------------------------------------------------
class TestUpsertPersonaMarksSyncFlag:
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
def test_creating_persona_with_files_marks_sync(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
user = create_test_user(db_session, "upsert_create")
uf = _create_completed_user_file(db_session, user)
assert uf.needs_persona_sync is False
upsert_persona(
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
system_prompt="test",
task_prompt="test",
datetime_aware=None,
is_public=True,
db_session=db_session,
user_file_ids=[uf.id],
)
db_session.refresh(uf)
assert uf.needs_persona_sync is True
def test_updating_persona_files_marks_both_old_and_new(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When file associations change, both the removed and added files are flagged."""
user = create_test_user(db_session, "upsert_update")
uf_old = _create_completed_user_file(db_session, user)
uf_new = _create_completed_user_file(db_session, user)
persona = upsert_persona(
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
system_prompt="test",
task_prompt="test",
datetime_aware=None,
is_public=True,
db_session=db_session,
user_file_ids=[uf_old.id],
)
# Clear the flag from creation so we can observe the update
uf_old.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
# Now update the persona to swap files
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=None,
is_public=persona.is_public,
db_session=db_session,
persona_id=persona.id,
user_file_ids=[uf_new.id],
)
db_session.refresh(uf_old)
db_session.refresh(uf_new)
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
def test_removing_all_files_marks_old_files(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""Removing all files from a persona flags the previously associated files."""
user = create_test_user(db_session, "upsert_remove")
uf = _create_completed_user_file(db_session, user)
persona = upsert_persona(
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
system_prompt="test",
task_prompt="test",
datetime_aware=None,
is_public=True,
db_session=db_session,
user_file_ids=[uf.id],
)
uf.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=None,
is_public=persona.is_public,
db_session=db_session,
persona_id=persona.id,
user_file_ids=[],
)
db_session.refresh(uf)
assert uf.needs_persona_sync is True

View File

@@ -0,0 +1,318 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
Mocks the LLM tokenizer and file store since they are not relevant here.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.models import UserProject
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _create_user_file(db_session: Session, user: User) -> UserFile:
uf = UserFile(
id=uuid4(),
user_id=user.id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.COMPLETED,
chunk_count=1,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
def _create_persona(db_session: Session, user: User) -> Persona:
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="test",
task_prompt="test",
tools=[],
document_sets=[],
users=[user],
groups=[],
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,
deleted=False,
user_id=user.id,
)
db_session.add(persona)
db_session.commit()
db_session.refresh(persona)
return persona
def _create_project(db_session: Session, user: User) -> UserProject:
project = UserProject(
user_id=user.id,
name=f"project-{uuid4().hex[:8]}",
instructions="",
)
db_session.add(project)
db_session.commit()
db_session.refresh(project)
return project
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
doc = Document(
id=str(user_file.id),
source=DocumentSource.USER_FILE,
semantic_identifier=user_file.name,
sections=[TextSection(text="test chunk content", link=None)],
metadata={},
)
return IndexChunk(
source_document=doc,
chunk_id=0,
blurb="test chunk",
content="test chunk content",
source_links={0: ""},
image_file_id=None,
section_continuation=False,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.0] * 768,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAdapterWritesBothMetadataFields:
"""build_metadata_aware_chunks must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
side_effect=Exception("no LLM in test"),
)
def test_file_linked_to_persona_gets_persona_id(
self,
_mock_llm: MagicMock,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
user = create_test_user(db_session, "adapter_persona")
uf = _create_user_file(db_session, user)
persona = _create_persona(db_session, user)
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
db_session.commit()
adapter = UserFileIndexingAdapter(
tenant_id=TEST_TENANT_ID, db_session=db_session
)
chunk = _make_index_chunk(uf)
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
context=context,
)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
side_effect=Exception("no LLM in test"),
)
def test_file_linked_to_project_gets_project_id(
self,
_mock_llm: MagicMock,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
user = create_test_user(db_session, "adapter_project")
uf = _create_user_file(db_session, user)
project = _create_project(db_session, user)
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
db_session.commit()
adapter = UserFileIndexingAdapter(
tenant_id=TEST_TENANT_ID, db_session=db_session
)
chunk = _make_index_chunk(uf)
context = DocumentBatchPrepareContext(
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
context=context,
)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
side_effect=Exception("no LLM in test"),
)
def test_file_linked_to_both_gets_both_ids(
self,
_mock_llm: MagicMock,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
user = create_test_user(db_session, "adapter_both")
uf = _create_user_file(db_session, user)
persona = _create_persona(db_session, user)
project = _create_project(db_session, user)
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
db_session.commit()
adapter = UserFileIndexingAdapter(
tenant_id=TEST_TENANT_ID, db_session=db_session
)
chunk = _make_index_chunk(uf)
context = DocumentBatchPrepareContext(
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
context=context,
)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
side_effect=Exception("no LLM in test"),
)
def test_file_with_no_associations_gets_empty_lists(
self,
_mock_llm: MagicMock,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
user = create_test_user(db_session, "adapter_empty")
uf = _create_user_file(db_session, user)
adapter = UserFileIndexingAdapter(
tenant_id=TEST_TENANT_ID, db_session=db_session
)
chunk = _make_index_chunk(uf)
context = DocumentBatchPrepareContext(
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
context=context,
)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
side_effect=Exception("no LLM in test"),
)
def test_multiple_personas_all_appear(
self,
_mock_llm: MagicMock,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file linked to multiple personas should have all their IDs."""
user = create_test_user(db_session, "adapter_multi")
uf = _create_user_file(db_session, user)
persona_a = _create_persona(db_session, user)
persona_b = _create_persona(db_session, user)
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
db_session.commit()
adapter = UserFileIndexingAdapter(
tenant_id=TEST_TENANT_ID, db_session=db_session
)
chunk = _make_index_chunk(uf)
context = DocumentBatchPrepareContext(
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
context=context,
)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
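These four adapter tests repeat the same arrangement: construct the adapter, build one chunk, wrap its source document in a DocumentBatchPrepareContext, and call build_metadata_aware_chunks. A sketch of a shared helper that would factor out that boilerplate (the helper itself is hypothetical; every call in it is taken verbatim from the tests above):

# Hypothetical helper mirroring the setup repeated in each test above.
def _build_single_chunk_result(db_session: Session, uf):  # uf: a UserFile row
    adapter = UserFileIndexingAdapter(
        tenant_id=TEST_TENANT_ID, db_session=db_session
    )
    chunk = _make_index_chunk(uf)
    context = DocumentBatchPrepareContext(
        updatable_docs=[chunk.source_document], id_to_boost_map={}
    )
    return adapter.build_metadata_aware_chunks(
        chunks_with_embeddings=[chunk],
        chunk_content_scores=[1.0],
        tenant_id=TEST_TENANT_ID,
        context=context,
    )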

View File

@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -44,15 +45,14 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> None:
) -> LLMProviderView:
"""Helper to create a test LLM provider in the database."""
upsert_llm_provider(
return upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
upsert_llm_provider(
provider = upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name,
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
# Test should fail
with patch(
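The recurring edit in this file: LLMTestRequest drops name, default_model_name, and model_configurations in favor of an optional provider id plus a single model name, and update_default_provider now takes the default model's name alongside the provider id. A condensed before/after sketch with illustrative values (imports are the ones shown in the first hunk above):

# Before this PR: provider referenced by name, full model payload required.
LLMTestRequest(
    name="my-provider",
    provider=LlmProviderNames.OPENAI,
    api_key_changed=False,
    custom_config_changed=False,
    default_model_name="gpt-4o-mini",
    model_configurations=[
        ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
    ],
)

# After: reference an existing provider by id (None for one not yet in the
# DB) and name the model under test directly.
LLMTestRequest(
    id=provider.id,
    provider=LlmProviderNames.OPENAI,
    api_key_changed=False,
    custom_config_changed=False,
    model="gpt-4o-mini",
)

# Setting the default provider now pins its default model explicitly.
update_default_provider(provider.id, "gpt-4o-mini", db_session)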

View File

@@ -49,7 +49,6 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when the provider has no custom base URL.
"""
try:
_create_test_provider(db_session, provider_name, api_base=None)
view = _create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
name=name,
id=provider.id,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=False,
),
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
assert False, "Provider not found in the database"
assert provider.custom_config == custom_config, (
assert db_provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {provider.custom_config}"
f"but got {db_provider.custom_config}"
)
finally:
db_session.rollback()
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
put_llm_provider(
view = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider = LlmProviderNames.VERTEX_AI.value
provider_name = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
provider=provider_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
id=provider.id,
provider=provider_name,
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config={
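The invariant this file pins down: in multi-tenant deployments, a request that redirects LLM traffic (a changed api_base, or a custom_config key such as OPENAI_API_BASE) is rejected unless the API key is re-verified in the same request via api_key_changed=True, while resubmitting the unchanged value (or an empty string where none was set) passes. A condensed sketch of that guard, offered as an assumption for illustration; the real check lives in onyx.server.manage.llm.api and may be structured differently:

# Illustrative only; not the actual implementation.
def reject_unverified_endpoint_change(
    request: LLMProviderUpsertRequest,
    existing_api_base: str | None,
    multi_tenant: bool,
) -> None:
    # Treat "" and None as equivalent, matching the empty-api_base test.
    requested = request.api_base or None
    current = existing_api_base or None
    if multi_tenant and requested != current and not request.api_key_changed:
        raise HTTPException(
            status_code=400,
            detail="Changing the API base requires re-entering the API key",
        )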

View File

@@ -15,9 +15,11 @@ import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, expected_default_model, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
)
# Verify initial state: all models are visible
provider = fetch_existing_llm_provider(
initial_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
assert initial_provider is not None
assert initial_provider.is_auto_mode is False
for mc in provider.model_configurations:
for mc in initial_provider.model_configurations:
assert (
mc.is_visible is True
), f"Initial model '{mc.name}' should be visible"
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=initial_provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
provider_view = fetch_llm_provider_view(
provider_name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is True
assert provider_view is not None
assert provider_view.is_auto_mode is True
# Build a map of model name -> visibility
model_visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
mc.name: mc.is_visible for mc in provider_view.model_configurations
}
# Models in auto mode config should be visible
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
),
is_creation=True,
_=_create_mock_admin(),
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_default_model, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
class TestAutoModeTransitionsAndResync:
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Disabling auto mode should preserve the current model list and
prevent future syncs from altering visibility.
Steps:
1. Create provider in auto mode — models synced from config.
2. Update provider to manual mode (is_auto_mode=False).
3. Verify all models remain with unchanged visibility.
4. Call sync_auto_mode_models with a *different* config.
5. Verify fetch_auto_mode_providers excludes this provider, so the
periodic task would never call sync on it.
"""
initial_config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create in auto mode
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=initial_config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility_before = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
# Step 2: Switch to manual mode
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
is_auto_mode=False,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
),
],
),
is_creation=False,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 3: Models unchanged
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
visibility_after = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_after == visibility_before
# Step 4-5: Provider excluded from auto mode queries
auto_providers = fetch_auto_mode_providers(db_session)
auto_provider_ids = {p.id for p in auto_providers}
assert provider.id not in auto_provider_ids
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_resync_adds_new_and_hides_removed_models(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the GitHub config changes between syncs, a subsequent sync
should add newly listed models and hide models that were removed.
Steps:
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
gpt-4-turbo added).
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
and visible.
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4-turbo"],
)
try:
# Step 1: Create with config v1
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Re-sync with config v2
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 3: Verify
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
# gpt-4o: still in config -> visible
assert visibility["gpt-4o"] is True
# gpt-4o-mini: removed from config -> hidden (not deleted)
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
assert visibility["gpt-4o-mini"] is False
# gpt-4-turbo: newly added -> visible
assert visibility["gpt-4-turbo"] is True
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_sync_is_idempotent(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Running sync twice with the same config should produce zero
changes on the second call."""
config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
# First explicit sync (may report changes if creation already synced)
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
# Snapshot state after first sync
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
snapshot = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
# Second sync — should be a no-op
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
assert (
changes == 0
), f"Expected 0 changes on idempotent re-sync, got {changes}"
# State should be identical
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
current = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
assert current == snapshot
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_default_model_hidden_when_removed_from_config(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the current default model is removed from the config, sync
should hide it. The default model flow row should still exist (it
points at the ModelConfiguration), but the model is no longer visible.
Steps:
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
2. Set gpt-4o as the global default.
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
fetch_default_llm_model still returns a result (the flow row persists).
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o-mini",
additional_models=[],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Set gpt-4o as global default
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
update_default_provider(provider.id, "gpt-4o", db_session)
default_before = fetch_default_llm_model(db_session)
assert default_before is not None
assert default_before.name == "gpt-4o"
# Step 3: Re-sync with config v2 (gpt-4o removed)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 4: Verify visibility
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
# but the model is hidden. fetch_default_llm_model filters on
# is_visible=True, so it should NOT return gpt-4o.
db_session.expire_all()
default_after = fetch_default_llm_model(db_session)
assert (
default_after is None or default_after.name != "gpt-4o"
), "Hidden model should not be returned as the default"
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
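Taken together, the new tests assert a compact contract for sync_auto_mode_models: recommended models are added or made visible, models that drop out of the recommendations are hidden rather than deleted, and a second sync against the same config reports zero changes. A self-contained restatement of that contract over a plain visibility map (illustrative; the real function operates on ORM rows and also handles display names and flows):

# Pure-Python restatement of the sync contract asserted above.
def sync_visibility(
    existing: dict[str, bool], recommended: set[str]
) -> tuple[dict[str, bool], int]:
    result = dict(existing)
    changes = 0
    for name in recommended:
        if result.get(name) is not True:
            result[name] = True  # add new models, or re-show hidden ones
            changes += 1
    for name, visible in existing.items():
        if name not in recommended and visible:
            result[name] = False  # hide removed models; never delete them
            changes += 1
    return result, changes

state, n = sync_visibility(
    {"gpt-4o": True, "gpt-4o-mini": True}, {"gpt-4o", "gpt-4-turbo"}
)
assert state == {"gpt-4o": True, "gpt-4o-mini": False, "gpt-4-turbo": True}
assert n == 2
_, n2 = sync_visibility(state, {"gpt-4o", "gpt-4-turbo"})
assert n2 == 0  # idempotent: re-syncing the same config is a no-op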

View File

@@ -64,7 +64,6 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(public_provider_id, db_session)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
try:
# Create chat session

View File

@@ -1,4 +1,4 @@
"""External dependency unit tests for OpenSearchClient.
"""External dependency unit tests for OpenSearchIndexClient.
These tests assume OpenSearch is running and test all implemented methods
using real schemas, pipelines, and search queries from the codebase.
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.opensearch_document_index import (
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
@pytest.fixture(scope="function")
def test_client(
opensearch_available: None, # noqa: ARG001
) -> Generator[OpenSearchClient, None, None]:
) -> Generator[OpenSearchIndexClient, None, None]:
"""Creates an OpenSearch client for testing with automatic cleanup."""
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
client = OpenSearchClient(index_name=test_index_name)
client = OpenSearchIndexClient(index_name=test_index_name)
yield client # Test runs here.
@@ -142,7 +142,7 @@ def test_client(
@pytest.fixture(scope="function")
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
"""Creates a search pipeline for testing with automatic cleanup."""
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
class TestOpenSearchClient:
"""Tests for OpenSearchClient."""
"""Tests for OpenSearchIndexClient."""
def test_create_index(self, test_client: OpenSearchClient) -> None:
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests creating an index with a real schema."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
# Verify index exists.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests deleting an existing index returns True."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
assert result is True
assert test_client.validate_index(expected_mappings=mappings) is False
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests deleting a nonexistent index returns False."""
# Under test.
# Don't create index, just try to delete.
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
# Postcondition.
assert result is False
def test_index_exists(self, test_client: OpenSearchClient) -> None:
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
"""Tests checking if an index exists."""
# Precondition.
# Index should not exist before creation.
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
# Index should exist after creation.
assert test_client.index_exists() is True
def test_validate_index(self, test_client: OpenSearchClient) -> None:
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests validating an index."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
# Should return True after creation.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
"""Tests put_mapping with same schema is idempotent."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
# Applying the same mappings again should succeed.
test_client.put_mapping(mappings)
# Postcondition.
# Index should still be valid.
assert test_client.validate_index(expected_mappings=mappings)
def test_put_mapping_adds_new_field(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping successfully adds new fields to existing index."""
# Precondition.
# Create index with minimal schema (just required fields).
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
# Add a new field using put_mapping.
updated_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
# New field
"new_test_field": {"type": "keyword"},
},
}
# Should not raise.
test_client.put_mapping(updated_mappings)
# Postcondition.
# Validate the new schema includes the new field.
assert test_client.validate_index(expected_mappings=updated_mappings)
def test_put_mapping_fails_on_type_change(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping fails when trying to change existing field type."""
# Precondition.
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
# Try to change test_field type from keyword to text.
conflicting_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "text"}, # Changed from keyword to text
},
}
# Should raise because field type cannot be changed.
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
test_client.put_mapping(conflicting_mappings)
def test_put_mapping_on_nonexistent_index(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping on non-existent index raises an error."""
# Precondition.
# Index does not exist yet.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
# Under test and postcondition.
with pytest.raises(Exception, match="index_not_found_exception|404"):
test_client.put_mapping(mappings)
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests creating an index twice raises an error."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
with pytest.raises(Exception, match="already exists"):
test_client.create_index(mappings=mappings, settings=settings)
def test_update_settings(self, test_client: OpenSearchClient) -> None:
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
"""Tests that update_settings raises NotImplementedError."""
# Under test and postcondition.
with pytest.raises(NotImplementedError):
test_client.update_settings(settings={})
def test_create_and_delete_search_pipeline(
self, test_client: OpenSearchClient
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests creating and deleting a search pipeline."""
# Under test and postcondition.
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
)
def test_index_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a document."""
# Precondition.
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
)
def test_bulk_index_documents(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests bulk indexing documents."""
# Precondition.
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
)
def test_index_duplicate_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a duplicate document raises an error."""
# Precondition.
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
test_client.index_document(document=doc, tenant_state=tenant_state)
def test_get_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a document."""
# Precondition.
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
assert retrieved_doc == original_doc
def test_get_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a nonexistent document raises an error."""
# Precondition.
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
)
def test_delete_existing_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting an existing document returns True."""
# Precondition.
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
test_client.get_document(document_chunk_id=doc_chunk_id)
def test_delete_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting a nonexistent document returns False."""
# Precondition.
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
assert result is False
def test_delete_by_query(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting documents by query."""
# Precondition.
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
assert len(keep_ids) == 1
def test_update_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a document's properties."""
# Precondition.
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
assert updated_doc.public == doc.public
def test_update_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a nonexistent document raises an error."""
# Precondition.
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
def test_search_empty_index(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
previous_score = current_score
def test_delete_by_query_multitenant_isolation(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
assert set(remaining_y_ids) == expected_y_ids
def test_delete_by_query_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query for non-existent document returns 0 deleted.
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
assert num_deleted == 0
def test_search_for_document_ids(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests search_for_document_ids method returns correct chunk IDs."""
# Precondition.
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
assert set(chunk_ids) == expected_ids
def test_search_with_no_document_access_can_retrieve_all_documents(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests search with no document access can retrieve all documents, even
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
def test_time_cutoff_filter(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
)
def test_random_search(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests the random search query works."""
# Precondition.
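Beyond the OpenSearchClient-to-OpenSearchIndexClient rename, this file adds coverage for put_mapping, and the tests encode its semantics: re-applying an identical mapping succeeds, new fields merge into an existing index, changing an existing field's type fails, and calling it on a missing index raises. Assuming the client wraps the standard opensearch-py call (an assumption; the real wrapper may differ), the core of it would look like:

# Sketch assuming opensearch-py; PUT /<index>/_mapping is additive only.
from opensearchpy import OpenSearch

def put_mapping(client: OpenSearch, index_name: str, mappings: dict) -> None:
    # Merges new fields into the index. A type change on an existing field
    # fails with illegal_argument_exception; a missing index yields
    # index_not_found_exception (HTTP 404).
    client.indices.put_mapping(index=index_name, body=mappings)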

View File

@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
def _delete_document_chunks_from_opensearch(
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
) -> None:
opensearch_client.refresh_index()
query_body = DocumentQuery.delete_from_document_id_query(
@@ -283,10 +284,10 @@ def vespa_document_index(
def opensearch_client(
db_session: Session,
full_deployment_setup: None, # noqa: ARG001
) -> Generator[OpenSearchClient, None, None]:
) -> Generator[OpenSearchIndexClient, None, None]:
"""Creates an OpenSearch client for the test tenant."""
active = get_active_search_settings(db_session)
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
@pytest.fixture(scope="module")
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
def test_documents(
db_session: Session,
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
patch_get_vespa_chunks_page_size: int,
) -> Generator[list[Document], None, None]:
"""
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002

View File

@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
name=provider_name,
provider="openai",
api_key="test-api-key",
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,

View File

@@ -434,7 +434,6 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -448,7 +447,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, db_session)
update_default_provider(provider_view.id, "gpt-4o", db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""

View File

@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
from onyx.db.models import OAuthConfig
from onyx.db.oauth_config import create_oauth_config
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.sensitive import SensitiveValue
from tests.external_dependency_unit.conftest import create_test_user
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
# Should use & instead of ? since URL already has query params
assert "foo=bar&" in url or "?foo=bar" in url
assert "client_id=custom_client_id" in url
class TestUnwrapSensitiveStr:
"""Tests for _unwrap_sensitive_str static method"""
def test_unwrap_sensitive_str(self) -> None:
"""Test that both SensitiveValue and plain str inputs are handled"""
# SensitiveValue input
sensitive = SensitiveValue[str](
encrypted_bytes=b"test_client_id",
decrypt_fn=lambda b: b.decode(),
)
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
# Plain str input
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
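The helper under test evidently normalizes either input to a plain string. A sketch of what such a method could look like (the accessor on SensitiveValue is hypothetical; the test only fixes the observable behavior):

def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
    if isinstance(value, SensitiveValue):
        # Hypothetical accessor; the real class is constructed from
        # encrypted_bytes plus a decrypt_fn, so it likely decrypts here.
        return value.get_value()
    return value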

View File

@@ -76,9 +76,12 @@ class ChatSessionManager:
user_performing_action: DATestUser,
persona_id: int = 0,
description: str = "Test chat session",
project_id: int | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
persona_id=persona_id,
description=description,
project_id=project_id,
)
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",

View File

@@ -4,10 +4,12 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -32,7 +34,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -65,7 +66,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
default_model_name=default_model_name or "gpt-4o-mini",
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -75,9 +76,19 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -104,7 +115,7 @@ class LLMProviderManager:
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
return [LLMProviderView(**p) for p in response.json()["providers"]]
@staticmethod
def verify(
@@ -113,7 +124,11 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -126,11 +141,30 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and (
default_model is None or default_model.model_name in model_names
)
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
default_text = response.json().get("default_text")
if default_text is None:
return None
return DefaultModel(**default_text)
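
Two API changes drive the edits above: the default model is now set through a dedicated POST /admin/llm/default endpoint that takes both the provider id and the model name, and GET /admin/llm/provider returns an object wrapping a "providers" list (with the default-model info under "default_text") instead of a bare array. A sketch of both calls; API_SERVER_URL, headers, and provider_id are assumed to come from the test environment:

import requests

API_SERVER_URL = "http://localhost:8080"  # assumed test deployment URL
headers = {"Authorization": "Bearer <admin-token>"}  # assumed admin auth
provider_id = 1  # assumed id of an existing provider

# Set the org-wide default model for a provider.
resp = requests.post(
    f"{API_SERVER_URL}/admin/llm/default",
    json={"provider_id": provider_id, "model_name": "gpt-4o-mini"},
    headers=headers,
)
resp.raise_for_status()

# Listing now wraps providers and exposes the default separately.
listing = requests.get(f"{API_SERVER_URL}/admin/llm/provider", headers=headers)
listing.raise_for_status()
providers = listing.json()["providers"]
default_info = listing.json().get("default_text")  # None when no default is set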

View File

@@ -0,0 +1,79 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestScimToken
from tests.integration.common_utils.test_models import DATestUser
class ScimTokenManager:
@staticmethod
def create(
name: str,
user_performing_action: DATestUser,
) -> DATestScimToken:
response = requests.post(
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
json={"name": name},
headers=user_performing_action.headers,
timeout=60,
)
response.raise_for_status()
data = response.json()
return DATestScimToken(
id=data["id"],
name=data["name"],
token_display=data["token_display"],
is_active=data["is_active"],
created_at=data["created_at"],
last_used_at=data.get("last_used_at"),
raw_token=data["raw_token"],
)
@staticmethod
def get_active(
user_performing_action: DATestUser,
) -> DATestScimToken | None:
response = requests.get(
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
headers=user_performing_action.headers,
timeout=60,
)
if response.status_code == 404:
return None
response.raise_for_status()
data = response.json()
return DATestScimToken(
id=data["id"],
name=data["name"],
token_display=data["token_display"],
is_active=data["is_active"],
created_at=data["created_at"],
last_used_at=data.get("last_used_at"),
)
@staticmethod
def get_scim_headers(raw_token: str) -> dict[str, str]:
return {
**GENERAL_HEADERS,
"Authorization": f"Bearer {raw_token}",
}
@staticmethod
def scim_get(
path: str,
raw_token: str,
) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=ScimTokenManager.get_scim_headers(raw_token),
timeout=60,
)
@staticmethod
def scim_get_no_auth(path: str) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/scim/v2{path}",
headers=GENERAL_HEADERS,
timeout=60,
)
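
Typical usage ties the manager's pieces together: create a token as an admin (the raw value is revealed exactly once), then authenticate SCIM calls with it. A short sketch assuming an admin_user test fixture:

token = ScimTokenManager.create(
    name="Provisioning Token",
    user_performing_action=admin_user,
)
assert token.raw_token is not None  # only available at creation time

# SCIM requests authenticate with the bearer token.
users_resp = ScimTokenManager.scim_get("/Users", token.raw_token)
users_resp.raise_for_status()

# Later metadata lookups never include the raw token again.
active = ScimTokenManager.get_active(user_performing_action=admin_user)
assert active is not None and active.raw_token is None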

View File

@@ -42,6 +42,18 @@ class DATestPAT(BaseModel):
last_used_at: str | None = None
class DATestScimToken(BaseModel):
"""SCIM bearer token model for testing."""
id: int
name: str
raw_token: str | None = None # Only present on initial creation
token_display: str
is_active: bool
created_at: str
last_used_at: str | None = None
class DATestAPIKey(BaseModel):
api_key_id: int
api_key_display: str
@@ -116,7 +128,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str
default_model_name: str | None = None
is_public: bool
is_auto_mode: bool = False
groups: list[int]

View File

@@ -42,12 +42,10 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,
"custom_config": None,
"fast_default_model_name": default_model,
"is_public": True,
"is_auto_mode": is_auto_mode,
"groups": [],
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json():
for provider in response.json()["providers"]:
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
default_model["name"] if isinstance(default_model, dict) else default_model
)
assert synced_provider["default_model_name"] == expected_default, (
f"Default model should be {expected_default}, "
f"got {synced_provider['default_model_name']}"
)
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None, # noqa: ARG001
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
f"Manual mode provider models should not change. "
f"Initial: {initial_models}, Current: {current_models}"
)
assert (
updated_provider["default_model_name"] == custom_model
), f"Manual mode default model should remain {custom_model}"

View File

@@ -6,20 +6,21 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMModelFlow
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_llm_for_persona
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
@@ -41,24 +42,30 @@ def _create_llm_provider(
is_public: bool,
is_default: bool,
) -> LLMProviderModel:
provider = LLMProviderModel(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
# Use None instead of False to avoid unique constraint violation
# The is_default_provider column has unique=True, so only one True and one False allowed
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
_provider = upsert_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name,
is_visible=True,
)
],
),
db_session=db_session,
)
db_session.add(provider)
db_session.flush()
if is_default:
update_default_provider(_provider.id, default_model_name, db_session)
provider = db_session.get(LLMProviderModel, _provider.id)
if not provider:
raise ValueError(f"Provider {name} not found")
return provider
@@ -270,24 +277,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()
@@ -321,13 +310,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
assert (
allowed_llm.config.model_name
== restricted_provider.model_configurations[0].name
)
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert fallback_llm.config.model_name == default_provider.default_model_name
assert (
fallback_llm.config.model_name
== default_provider.model_configurations[0].name
)
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -346,6 +341,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
name="public-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)
@@ -365,7 +361,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -380,7 +376,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_providers = admin_response.json()["providers"]
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
@@ -396,6 +392,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
name="default-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)

View File

@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider since basic_user can access the persona
provider_names = [p["name"] for p in providers]
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
# Should succeed (persona_id=0 refers to default persona, which is public)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should NOT include the restricted provider since basic_user is not in group2
provider_names = [p["name"] for p in providers]
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
# Should succeed - admins can access any persona
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider
provider_names = [p["name"] for p in providers]
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should return the public provider
assert len(providers) > 0

View File

@@ -23,6 +23,8 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
@@ -34,6 +36,8 @@ class NightlyProviderConfig(BaseModel):
model_names: list[str]
api_key: str | None
api_base: str | None
api_version: str | None
deployment_name: str | None
custom_config: dict[str, str] | None
strict: bool
@@ -45,17 +49,29 @@ def _env_true(env_var: str, default: bool = False) -> bool:
return value.strip().lower() in {"1", "true", "yes", "on"}
def _split_csv_env(env_var: str) -> list[str]:
return [
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
]
def _parse_models_env(env_var: str) -> list[str]:
raw_value = os.environ.get(env_var, "").strip()
if not raw_value:
return []
try:
parsed_json = json.loads(raw_value)
except json.JSONDecodeError:
parsed_json = None
if isinstance(parsed_json, list):
return [str(model).strip() for model in parsed_json if str(model).strip()]
return [part.strip() for part in raw_value.split(",") if part.strip()]
def _load_provider_config() -> NightlyProviderConfig:
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
model_names = _split_csv_env(_ENV_MODELS)
model_names = _parse_models_env(_ENV_MODELS)
api_key = os.environ.get(_ENV_API_KEY) or None
api_base = os.environ.get(_ENV_API_BASE) or None
api_version = os.environ.get(_ENV_API_VERSION) or None
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
strict = _env_true(_ENV_STRICT, default=False)
custom_config: dict[str, str] | None = None
@@ -74,6 +90,8 @@ def _load_provider_config() -> NightlyProviderConfig:
model_names=model_names,
api_key=api_key,
api_base=api_base,
api_version=api_version,
deployment_name=deployment_name,
custom_config=custom_config,
strict=strict,
)
@@ -95,10 +113,15 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
message=f"{_ENV_MODELS} must include at least one model",
)
if config.provider != "ollama_chat" and not config.api_key:
if config.provider != "ollama_chat" and not (
config.api_key or config.custom_config
):
_skip_or_fail(
strict=config.strict,
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
message=(
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
f"provider '{config.provider}'"
),
)
if config.provider == "ollama_chat" and not (
@@ -109,6 +132,22 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
)
if config.provider == "azure":
if not config.api_base:
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
),
)
if not config.api_version:
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
),
)
def _assert_integration_mode_enabled() -> None:
assert (
@@ -147,6 +186,8 @@ def _create_provider_payload(
model_name: str,
api_key: str | None,
api_base: str | None,
api_version: str | None,
deployment_name: str | None,
custom_config: dict[str, str] | None,
) -> dict:
return {
@@ -154,6 +195,8 @@ def _create_provider_payload(
"provider": provider,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
"deployment_name": deployment_name,
"custom_config": custom_config,
"default_model_name": model_name,
"is_public": True,
@@ -255,6 +298,8 @@ def _create_and_test_provider_for_model(
model_name=model_name,
api_key=config.api_key,
api_base=resolved_api_base,
api_version=config.api_version,
deployment_name=config.deployment_name,
custom_config=config.custom_config,
)
@@ -313,10 +358,21 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
_seed_connector_for_search_tool(admin_user)
search_tool_id = _get_internal_search_tool_id(admin_user)
failures: list[str] = []
for model_name in config.model_names:
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)
try:
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)
except BaseException as exc:
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
raise
failures.append(
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
)
if failures:
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
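
With the _parse_models_env change above, NIGHTLY_LLM_MODELS accepts either a JSON array or the legacy comma-separated form; both of the following parse to the same list:

import os

os.environ["NIGHTLY_LLM_MODELS"] = '["gpt-4o", "gpt-4o-mini"]'
assert _parse_models_env("NIGHTLY_LLM_MODELS") == ["gpt-4o", "gpt-4o-mini"]

os.environ["NIGHTLY_LLM_MODELS"] = "gpt-4o, gpt-4o-mini"
assert _parse_models_env("NIGHTLY_LLM_MODELS") == ["gpt-4o", "gpt-4o-mini"]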

View File

@@ -72,6 +72,9 @@ def test_cold_startup_default_assistant() -> None:
assert (
"read_file" in tool_names
), "Default assistant should have FileReaderTool attached"
assert (
"python" in tool_names
), "Default assistant should have PythonTool attached"
# Also verify by display names for clarity
assert (
@@ -86,8 +89,11 @@ def test_cold_startup_default_assistant() -> None:
assert (
"File Reader" in tool_display_names
), "Default assistant should have File Reader tool"
# Should have exactly 5 tools
assert (
len(tool_associations) == 5
), f"Default assistant should have exactly 5 tools attached, got {len(tool_associations)}"
"Code Interpreter" in tool_display_names
), "Default assistant should have Code Interpreter tool"
# Should have exactly 6 tools
assert (
len(tool_associations) == 6
), f"Default assistant should have exactly 6 tools attached, got {len(tool_associations)}"

View File

@@ -0,0 +1,318 @@
"""
Integration tests for the unified persona file context flow.
End-to-end tests that verify:
1. Files can be uploaded and attached to a persona via API.
2. The persona correctly reports its attached files.
3. A chat session with a file-bearing persona processes without error.
4. Precedence: custom persona files take priority over project files when
the chat session is inside a project.
These tests run against a real Onyx deployment (all services running).
File processing is asynchronous, so we poll the file status endpoint
until files reach COMPLETED before chatting.
"""
import time
import requests
from onyx.db.enums import UserFileStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.project import ProjectManager
from tests.integration.common_utils.test_file_utils import create_test_text_file
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
FILE_PROCESSING_POLL_INTERVAL = 2
def _poll_file_statuses(
user_file_ids: list[str],
user: DATestUser,
target_status: UserFileStatus = UserFileStatus.COMPLETED,
timeout: int = MAX_DELAY,
) -> None:
"""Block until all files reach the target status or timeout expires."""
deadline = time.time() + timeout
while time.time() < deadline:
response = requests.post(
f"{API_SERVER_URL}/user/projects/file/statuses",
json={"file_ids": user_file_ids},
headers=user.headers,
)
response.raise_for_status()
statuses = response.json()
if all(f["status"] == target_status.value for f in statuses):
return
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
raise TimeoutError(
f"Files {user_file_ids} did not reach {target_status.value} "
f"within {timeout}s"
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_persona_with_files_chat_no_error(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""Upload files, attach them to a persona, wait for processing,
then send a chat message. Verify no error is returned."""
# Upload files (creates UserFile records)
text_file = create_test_text_file(
"The secret project codename is NIGHTINGALE. "
"It was started in 2024 by the Advanced Research division."
)
file_descriptors, error = FileManager.upload_files(
files=[("nightingale_brief.txt", text_file)],
user_performing_action=admin_user,
)
assert not error, f"File upload failed: {error}"
assert len(file_descriptors) == 1
user_file_id = file_descriptors[0]["user_file_id"]
assert user_file_id is not None
# Wait for file processing
_poll_file_statuses([user_file_id], admin_user, timeout=120)
# Create persona with the file attached
persona = PersonaManager.create(
user_performing_action=admin_user,
name="Nightingale Agent",
description="Agent with secret file",
system_prompt="You are a helpful assistant with access to uploaded files.",
user_file_ids=[user_file_id],
)
# Verify persona has the file
persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
assert len(persona_snapshots) == 1
assert user_file_id in persona_snapshots[0].user_file_ids
# Chat with the persona
chat_session = ChatSessionManager.create(
persona_id=persona.id,
description="Test persona file context",
user_performing_action=admin_user,
)
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="What is the secret project codename?",
user_performing_action=admin_user,
)
assert response.error is None, f"Chat should succeed, got error: {response.error}"
assert len(response.full_message) > 0, "Response should not be empty"
def test_persona_without_files_still_works(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""A persona with no attached files should still chat normally."""
persona = PersonaManager.create(
user_performing_action=admin_user,
name="Blank Agent",
description="No files attached",
system_prompt="You are a helpful assistant.",
)
chat_session = ChatSessionManager.create(
persona_id=persona.id,
description="Test blank persona",
user_performing_action=admin_user,
)
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="Hello, how are you?",
user_performing_action=admin_user,
)
assert response.error is None
assert len(response.full_message) > 0
def test_persona_files_override_project_files(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""When a custom persona (with its own files) is used inside a project,
the persona's files take precedence — the project's files are invisible.
We verify this by putting different content in project vs persona files
and checking which content the model responds with."""
# Upload persona file
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
persona_fds, err1 = FileManager.upload_files(
files=[("persona_secret.txt", persona_file)],
user_performing_action=admin_user,
)
assert not err1
persona_user_file_id = persona_fds[0]["user_file_id"]
assert persona_user_file_id is not None
# Create a project and upload project files
project = ProjectManager.create(
name="Precedence Test Project",
user_performing_action=admin_user,
)
project_files = [
("project_secret.txt", b"The project's secret word is FLAMINGO."),
]
project_upload_result = ProjectManager.upload_files(
project_id=project.id,
files=project_files,
user_performing_action=admin_user,
)
assert len(project_upload_result.user_files) == 1
project_user_file_id = str(project_upload_result.user_files[0].id)
# Wait for both persona and project file processing
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
# Create persona with persona file
persona = PersonaManager.create(
user_performing_action=admin_user,
name="Override Agent",
description="Persona with its own files",
system_prompt="You are a helpful assistant. Answer using the files.",
user_file_ids=[persona_user_file_id],
)
# Create chat session inside the project but using the custom persona
chat_session = ChatSessionManager.create(
persona_id=persona.id,
project_id=project.id,
user_performing_action=admin_user,
)
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="What is the secret word?",
user_performing_action=admin_user,
)
assert response.error is None, f"Chat should succeed, got error: {response.error}"
# The persona's file should be what the model sees, not the project's
message_lower = response.full_message.lower()
assert "albatross" in message_lower, (
"Response should reference the persona file's secret word (ALBATROSS), "
f"but got: {response.full_message}"
)
def test_default_persona_in_project_uses_project_files(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""When the default persona (id=0) is used inside a project,
the project's files should be used for context."""
project = ProjectManager.create(
name="Default Persona Project",
user_performing_action=admin_user,
)
project_files = [
("project_info.txt", b"The project mascot is a PANGOLIN."),
]
upload_result = ProjectManager.upload_files(
project_id=project.id,
files=project_files,
user_performing_action=admin_user,
)
assert len(upload_result.user_files) == 1
# Wait for project file processing
project_file_id = str(upload_result.user_files[0].id)
_poll_file_statuses([project_file_id], admin_user, timeout=120)
# Create chat session inside project using default persona (id=0)
chat_session = ChatSessionManager.create(
persona_id=0,
project_id=project.id,
user_performing_action=admin_user,
)
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="What is the project mascot?",
user_performing_action=admin_user,
)
assert response.error is None
assert "pangolin" in response.full_message.lower(), (
"Response should reference the project file content (PANGOLIN), "
f"but got: {response.full_message}"
)
def test_custom_persona_no_files_in_project_ignores_project(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""A custom persona with NO files, used inside a project with files,
should NOT see the project's files. The project is purely organizational.
We verify by asking about content only in the project file and checking
the model does NOT reference it."""
project = ProjectManager.create(
name="Ignored Project",
user_performing_action=admin_user,
)
project_upload_result = ProjectManager.upload_files(
project_id=project.id,
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
user_performing_action=admin_user,
)
assert len(project_upload_result.user_files) == 1
project_user_file_id = str(project_upload_result.user_files[0].id)
# Wait for project file processing
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
# Custom persona with no files
persona = PersonaManager.create(
user_performing_action=admin_user,
name="No Files Agent",
description="No files, project is irrelevant",
system_prompt=(
"You are a helpful assistant. If you do not have information "
"to answer a question, say 'I do not have that information.'"
),
)
chat_session = ChatSessionManager.create(
persona_id=persona.id,
project_id=project.id,
user_performing_action=admin_user,
)
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="What is the project secret?",
user_performing_action=admin_user,
)
assert response.error is None
assert len(response.full_message) > 0
assert "capybara" not in response.full_message.lower(), (
"Response should NOT reference the project file content (CAPYBARA) "
"because the custom persona has no files and should not inherit "
f"project files, but got: {response.full_message}"
)

View File

@@ -0,0 +1,166 @@
"""Integration tests for SCIM token management.
Covers the admin token API and SCIM bearer-token authentication:
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
2. Token rotation: creating a new token revokes previous tokens
3. Revoked tokens are rejected by SCIM endpoints
4. Non-admin users cannot manage SCIM tokens
5. SCIM requests without a token are rejected
6. Service discovery endpoints work without authentication
7. last_used_at is updated after a SCIM request
"""
import time
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
"""Create token → retrieve metadata → use for SCIM request."""
token = ScimTokenManager.create(
name="Test SCIM Token",
user_performing_action=admin_user,
)
assert token.raw_token is not None
assert token.raw_token.startswith("onyx_scim_")
assert token.is_active is True
assert "****" in token.token_display
# GET returns the same metadata but raw_token is None because the
# server only reveals the raw token once at creation time (it stores
# only the SHA-256 hash).
active = ScimTokenManager.get_active(user_performing_action=admin_user)
assert active == token.model_copy(update={"raw_token": None})
# Token works for SCIM requests
response = ScimTokenManager.scim_get("/Users", token.raw_token)
assert response.status_code == 200
body = response.json()
assert "Resources" in body
assert body["totalResults"] >= 0
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
"""Creating a new token automatically revokes the previous one."""
first = ScimTokenManager.create(
name="First Token",
user_performing_action=admin_user,
)
assert first.raw_token is not None
response = ScimTokenManager.scim_get("/Users", first.raw_token)
assert response.status_code == 200
# Create second token — should revoke first
second = ScimTokenManager.create(
name="Second Token",
user_performing_action=admin_user,
)
assert second.raw_token is not None
# Active token should now be the second one
active = ScimTokenManager.get_active(user_performing_action=admin_user)
assert active == second.model_copy(update={"raw_token": None})
# First token rejected, second works
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
def test_scim_request_without_token_rejected(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""SCIM endpoints reject requests with no Authorization header."""
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
def test_scim_request_with_bad_token_rejected(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""SCIM endpoints reject requests with an invalid token."""
assert (
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
== 401
)
def test_non_admin_cannot_create_token(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""Non-admin users get 403 when trying to create a SCIM token."""
basic_user = UserManager.create(name="scim_basic_user")
response = requests.post(
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
json={"name": "Should Fail"},
headers=basic_user.headers,
timeout=60,
)
assert response.status_code == 403
def test_non_admin_cannot_get_token(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""Non-admin users get 403 when trying to retrieve SCIM token metadata."""
basic_user = UserManager.create(name="scim_basic_user2")
response = requests.get(
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
headers=basic_user.headers,
timeout=60,
)
assert response.status_code == 403
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
"""GET active token returns 404 when no token exists."""
# new_admin_user depends on the reset fixture, ensuring a clean DB
# with no active SCIM tokens.
active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
assert active is None
response = requests.get(
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
headers=new_admin_user.headers,
timeout=60,
)
assert response.status_code == 404
def test_service_discovery_no_auth_required(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""Service discovery endpoints work without any authentication."""
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
response = ScimTokenManager.scim_get_no_auth(path)
assert response.status_code == 200, f"{path} returned {response.status_code}"
def test_last_used_at_updated_after_scim_request(
admin_user: DATestUser,
) -> None:
"""last_used_at timestamp is updated after using the token."""
token = ScimTokenManager.create(
name="Last Used Token",
user_performing_action=admin_user,
)
assert token.raw_token is not None
active = ScimTokenManager.get_active(user_performing_action=admin_user)
assert active is not None
assert active.last_used_at is None
# Make a SCIM request, then verify last_used_at is set
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
time.sleep(0.5)
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
assert active_after is not None
assert active_after.last_used_at is not None

View File

@@ -9,6 +9,19 @@ from redis.exceptions import RedisError
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
# Fields we assert on across all tests
_ASSERT_FIELDS = {
"application_status",
"ee_features_enabled",
"seat_count",
"used_seats",
}
def _pick(settings: Settings) -> dict:
"""Extract only the fields under test from a Settings object."""
return settings.model_dump(include=_ASSERT_FIELDS)
@pytest.fixture
def base_settings() -> Settings:
@@ -27,17 +40,17 @@ class TestApplyLicenseStatusToSettings:
def test_enforcement_disabled_enables_ee_features(
self, base_settings: Settings
) -> None:
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
If we're running the EE apply function, EE code was loaded via
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
"""
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
assert base_settings.ee_features_enabled is False
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
@@ -46,13 +59,60 @@ class TestApplyLicenseStatusToSettings:
from ee.onyx.server.settings.api import apply_license_status_to_settings
result = apply_license_status_to_settings(base_settings)
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@pytest.mark.parametrize(
"license_status,expected_app_status,expected_ee_enabled",
"license_status,used_seats,seats,expected",
[
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
(
ApplicationStatus.GATED_ACCESS,
3,
10,
{
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
10,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.GRACE_PERIOD,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
],
)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -63,25 +123,80 @@ class TestApplyLicenseStatusToSettings:
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
license_status: ApplicationStatus | None,
expected_app_status: ApplicationStatus,
expected_ee_enabled: bool,
license_status: ApplicationStatus,
used_seats: int,
seats: int,
expected: dict,
base_settings: Settings,
) -> None:
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
if license_status is None:
mock_get_metadata.return_value = None
else:
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_get_metadata.return_value = mock_metadata
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_metadata.used_seats = used_seats
mock_metadata.seats = seats
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert result.application_status == expected_app_status
assert result.ee_features_enabled is expected_ee_enabled
assert _pick(result) == expected
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_seat_limit_exceeded_sets_status_and_counts(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.ACTIVE
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
"ee_features_enabled": True,
"seat_count": 10,
"used_seats": 15,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_expired_license_takes_precedence_over_seat_limit(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.GATED_ACCESS
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -105,8 +220,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.GATED_ACCESS
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -130,8 +249,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@@ -150,8 +273,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.side_effect = RedisError("Connection failed")
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
class TestSettingsDefaultEEDisabled:
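
Read together, these cases pin down a precedence order: a gated (expired) license wins outright, seat overage comes next but keeps EE features on, and everything else reports ACTIVE. A hedged reconstruction of that logic; the real implementation lives in ee.onyx.server.settings.api and may differ in detail:

from onyx.server.settings.models import ApplicationStatus

def resolve_license_state(
    license_status: ApplicationStatus, used_seats: int, seats: int
) -> tuple[ApplicationStatus, bool, int | None, int | None]:
    """Returns (application_status, ee_features_enabled, seat_count, used_seats)."""
    # 1. An expired/gated license takes precedence and disables EE features.
    if license_status == ApplicationStatus.GATED_ACCESS:
        return ApplicationStatus.GATED_ACCESS, False, None, None
    # 2. Seat overage surfaces the counts but keeps EE features enabled.
    if used_seats > seats:
        return ApplicationStatus.SEAT_LIMIT_EXCEEDED, True, seats, used_seats
    # 3. Otherwise ACTIVE (GRACE_PERIOD is also reported as ACTIVE, EE on).
    return ApplicationStatus.ACTIVE, True, None, None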

View File

@@ -0,0 +1,426 @@
"""Tests for the unified context file extraction logic (Phase 5).
Covers:
- resolve_context_user_files: precedence rule (custom persona supersedes project)
- extract_context_files: all-or-nothing context window fit check
- Search filter / search_usage determination in the caller
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.process_message import determine_search_params
from onyx.chat.process_message import extract_context_files
from onyx.chat.process_message import resolve_context_user_files
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.models import UserFile
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.tools.models import SearchToolUsage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_user_file(
token_count: int = 100,
name: str = "file.txt",
file_id: str | None = None,
) -> UserFile:
file_uuid = UUID(file_id) if file_id else uuid4()
return UserFile(
id=file_uuid,
file_id=str(file_uuid),
name=name,
token_count=token_count,
)
def _make_persona(
persona_id: int,
user_files: list | None = None,
) -> MagicMock:
persona = MagicMock()
persona.id = persona_id
persona.user_files = user_files or []
return persona
def _make_in_memory_file(
file_id: str,
content: str = "hello world",
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
filename: str = "file.txt",
) -> InMemoryChatFile:
return InMemoryChatFile(
file_id=file_id,
content=content.encode("utf-8"),
file_type=file_type,
filename=filename,
)
# ===========================================================================
# resolve_context_user_files
# ===========================================================================
class TestResolveContextUserFiles:
"""Precedence rule: custom persona fully supersedes project."""
def test_custom_persona_with_files_returns_persona_files(self) -> None:
persona_files = [_make_user_file(), _make_user_file()]
persona = _make_persona(persona_id=42, user_files=persona_files)
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
)
assert result == persona_files
def test_custom_persona_without_files_returns_empty(self) -> None:
"""Custom persona with no files should NOT fall through to project."""
persona = _make_persona(persona_id=42, user_files=[])
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
)
assert result == []
def test_custom_persona_none_files_returns_empty(self) -> None:
"""Custom persona with user_files=None should NOT fall through."""
persona = _make_persona(persona_id=42, user_files=None)
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
)
assert result == []
@patch("onyx.chat.process_message.get_user_files_from_project")
def test_default_persona_in_project_returns_project_files(
self, mock_get_files: MagicMock
) -> None:
project_files = [_make_user_file(), _make_user_file()]
mock_get_files.return_value = project_files
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
user_id = uuid4()
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=99, user_id=user_id, db_session=db_session
)
assert result == project_files
mock_get_files.assert_called_once_with(
project_id=99, user_id=user_id, db_session=db_session
)
def test_default_persona_no_project_returns_empty(self) -> None:
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
)
assert result == []
@patch("onyx.chat.process_message.get_user_files_from_project")
def test_custom_persona_without_files_ignores_project(
self, mock_get_files: MagicMock
) -> None:
"""Even with a project_id, custom persona means project is invisible."""
persona = _make_persona(persona_id=7, user_files=[])
db_session = MagicMock()
result = resolve_context_user_files(
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
)
assert result == []
mock_get_files.assert_not_called()
# ===========================================================================
# extract_context_files
# ===========================================================================
class TestExtractContextFiles:
"""All-or-nothing context window fit check."""
def test_empty_user_files_returns_empty(self) -> None:
db_session = MagicMock()
result = extract_context_files(
user_files=[],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=db_session,
)
assert result.file_texts == []
assert result.image_files == []
assert result.use_as_search_filter is False
assert result.uncapped_token_count is None
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
file_id = str(uuid4())
uf = _make_user_file(token_count=100, file_id=file_id)
mock_load.return_value = [
_make_in_memory_file(file_id=file_id, content="file content")
]
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.file_texts == ["file content"]
assert result.use_as_search_filter is False
assert result.total_token_count == 100
assert len(result.file_metadata) == 1
assert result.file_metadata[0].file_id == file_id
def test_files_overflow_context_not_loaded(self) -> None:
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
uf = _make_user_file(token_count=7000)
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.file_texts == []
assert result.image_files == []
assert result.use_as_search_filter is True
assert result.uncapped_token_count == 7000
assert result.total_token_count == 0
def test_overflow_boundary_exact(self) -> None:
"""Token count exactly at the 60% boundary should trigger overflow."""
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
uf = _make_user_file(token_count=6000)
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is True
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
"""Token count just under the 60% boundary should load files."""
file_id = str(uuid4())
uf = _make_user_file(token_count=5999, file_id=file_id)
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is False
assert result.file_texts == ["data"]
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
"""Multiple small files that individually fit but collectively overflow."""
files = [_make_user_file(token_count=2500) for _ in range(3)]
# 3 * 2500 = 7500 > 6000 threshold
result = extract_context_files(
user_files=files,
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is True
assert result.file_texts == []
mock_load.assert_not_called()
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
"""Reserved tokens shrink the available window."""
file_id = str(uuid4())
uf = _make_user_file(token_count=3000, file_id=file_id)
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=5000,
db_session=MagicMock(),
)
assert result.use_as_search_filter is True
mock_load.assert_not_called()
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
file_id = str(uuid4())
uf = _make_user_file(token_count=50, file_id=file_id)
mock_load.return_value = [
InMemoryChatFile(
file_id=file_id,
content=b"\x89PNG",
file_type=ChatFileType.IMAGE,
filename="photo.png",
)
]
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert len(result.image_files) == 1
assert result.image_files[0].file_id == file_id
assert result.file_texts == []
assert result.total_token_count == 50
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
"""When vector DB is disabled, overflow produces FileToolMetadata."""
uf = _make_user_file(token_count=7000, name="bigfile.txt")
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is False
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
# ===========================================================================
# Search filter + search_usage determination
# ===========================================================================
class TestSearchFilterDetermination:
"""Verify that determine_search_params correctly resolves
search_project_id, search_persona_id, and search_usage based on
the extraction result and the precedence rule.
"""
@staticmethod
def _make_context(
use_as_search_filter: bool = False,
file_texts: list[str] | None = None,
uncapped_token_count: int | None = None,
) -> ExtractedContextFiles:
return ExtractedContextFiles(
file_texts=file_texts or [],
image_files=[],
use_as_search_filter=use_as_search_filter,
total_token_count=0,
file_metadata=[],
uncapped_token_count=uncapped_token_count,
)
def test_custom_persona_files_fit_no_filter(self) -> None:
"""Custom persona, files fit → no search filter, AUTO."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(
file_texts=["content"],
uncapped_token_count=100,
),
)
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_files_overflow_persona_filter(self) -> None:
"""Custom persona, files overflow → persona_id filter, AUTO."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(use_as_search_filter=True),
)
assert result.search_persona_id == 42
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_no_files_no_project_leak(self) -> None:
"""Custom persona (no files) in project → nothing leaks from project."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(),
)
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_default_persona_project_files_fit_disables_search(self) -> None:
"""Default persona, project files fit → DISABLED."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=99,
extracted_context_files=self._make_context(
file_texts=["content"],
uncapped_token_count=100,
),
)
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.DISABLED
def test_default_persona_project_files_overflow_enables_search(self) -> None:
"""Default persona, project files overflow → ENABLED + project_id filter."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=99,
extracted_context_files=self._make_context(
use_as_search_filter=True,
uncapped_token_count=7000,
),
)
assert result.search_project_id == 99
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.ENABLED
def test_default_persona_no_project_auto(self) -> None:
"""Default persona, no project → AUTO."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=None,
extracted_context_files=self._make_context(),
)
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_default_persona_project_no_files_disables_search(self) -> None:
"""Default persona in project with no files → DISABLED."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=99,
extracted_context_files=self._make_context(),
)
assert result.search_usage == SearchToolUsage.DISABLED
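
Taken together, the suites in this file fix the contract: a custom persona's file list fully supersedes project files (even when empty), and context loading is all-or-nothing against roughly 60% of the unreserved window, with the boundary itself counting as overflow. A hedged sketch of both rules; the real logic lives in onyx.chat.process_message and the 0.6 factor is inferred from the boundary tests:

from onyx.configs.constants import DEFAULT_PERSONA_ID

def context_files_for(persona, project_files: list) -> list:
    # A custom persona's files are authoritative, even when empty;
    # only the default persona falls through to the project's files.
    if persona.id != DEFAULT_PERSONA_ID:
        return list(persona.user_files or [])
    return project_files

def files_fit_in_context(
    total_tokens: int, llm_max_context_window: int, reserved_token_count: int
) -> bool:
    available = (llm_max_context_window - reserved_token_count) * 0.6
    # Strictly under the threshold loads everything; at or over it loads
    # nothing and flips use_as_search_filter instead (all-or-nothing).
    return total_tokens < available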

View File

@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
@@ -74,20 +74,20 @@ def create_tool_response(
)
def create_project_files(
def create_context_files(
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
) -> ExtractedProjectFiles:
"""Helper to create ExtractedProjectFiles for testing."""
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
project_file_metadata = [
ProjectFileMetadata(
) -> ExtractedContextFiles:
"""Helper to create ExtractedContextFiles for testing."""
file_texts = [f"Project file {i} content" for i in range(num_files)]
file_metadata = [
ContextFileMetadata(
file_id=f"file_{i}",
filename=f"file_{i}.txt",
file_content=f"Project file {i} content",
)
for i in range(num_files)
]
project_image_files = [
image_files = [
ChatLoadedFile(
file_id=f"image_{i}",
content=b"",
@@ -98,13 +98,13 @@ def create_project_files(
)
for i in range(num_images)
]
return ExtractedProjectFiles(
project_file_texts=project_file_texts,
project_image_files=project_image_files,
project_as_filter=False,
return ExtractedContextFiles(
file_texts=file_texts,
image_files=image_files,
use_as_search_filter=False,
total_token_count=num_files * tokens_per_file,
project_file_metadata=project_file_metadata,
project_uncapped_token_count=num_files * tokens_per_file,
file_metadata=file_metadata,
uncapped_token_count=num_files * tokens_per_file,
)
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
user_msg2 = create_message("How are you?", MessageType.USER, 5)
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
project_files = create_project_files()
context_files = create_context_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
project_files = create_project_files()
context_files = create_context_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
assert result[3] == custom_agent # Before last user message
assert result[4] == user_msg2
def test_with_project_files(self) -> None:
def test_with_context_files(self) -> None:
"""Test that project files are inserted before the last user message."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg1 = create_message("First message", MessageType.USER, 5)
user_msg2 = create_message("Second message", MessageType.USER, 5)
simple_chat_history = [user_msg1, user_msg2]
project_files = create_project_files(num_files=2, tokens_per_file=50)
context_files = create_context_files(num_files=2, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
# Should have: system, user1, project_files_message, user2
# Should have: system, user1, context_files_message, user2
assert len(result) == 4
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
simple_chat_history = [user_msg]
project_files = create_project_files()
context_files = create_context_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
assistant_with_tool,
tool_response,
]
project_files = create_project_files()
context_files = create_context_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom", MessageType.USER, 10)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
project_files = create_project_files(num_files=1, tokens_per_file=50)
context_files = create_context_files(num_files=1, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
assert len(result) == 6
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
user_msg2 = create_message("Second", MessageType.USER, 5)
simple_chat_history = [user_msg1, user_msg2]
project_files = create_project_files(num_files=0, num_images=2)
context_files = create_context_files(num_files=0, num_images=2)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
)
simple_chat_history = [user_msg]
project_files = create_project_files(num_files=0, num_images=1)
context_files = create_context_files(num_files=0, num_images=1)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
assistant_msg2,
user_msg3,
]
project_files = create_project_files()
context_files = create_context_files()
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
result = construct_message_history(
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=80,
)
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
tool_response = create_tool_response("tc_1", "tool_response", 20)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
project_files = create_project_files()
context_files = create_context_files()
# Budget only allows last user message and messages after + system
# (10 + 20 + 20 + 20 = 70 tokens)
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=80,
)
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
assistant_msg1,
user_msg2,
]
project_files = create_project_files()
context_files = create_context_files()
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
# keeps [tool_response, assistant_msg1] from history_before_last_user,
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=30,
)
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
user_msg2 = create_message("Latest question", MessageType.USER, 10)
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
project_files = create_project_files()
context_files = create_context_files()
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
# keeps both assistant_with_tool and tool_response in history_before_last_user.
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=45,
)
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
reminder = create_message("Reminder", MessageType.USER, 10)
simple_chat_history: list[ChatMessageSimple] = []
project_files = create_project_files(num_files=1, tokens_per_file=50)
context_files = create_context_files(num_files=1, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
# Should have: system, custom_agent, project_files, reminder
# Should have: system, custom_agent, context_files, reminder
assert len(result) == 4
assert result[0] == system_prompt
assert result[1] == custom_agent
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
simple_chat_history = [assistant_msg, assistant_with_tool]
project_files = create_project_files()
context_files = create_context_files()
with pytest.raises(ValueError, match="No user message found"):
construct_message_history(
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom", MessageType.USER, 50)
simple_chat_history = [user_msg]
project_files = create_project_files(num_files=1, tokens_per_file=100)
context_files = create_context_files(num_files=1, tokens_per_file=100)
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
# But only 200 available
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=200,
)
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
project_files = create_project_files()
context_files = create_context_files()
# Budget: 50 tokens
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=50,
)
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
assistant_with_tool,
tool_response,
]
project_files = create_project_files(num_files=2, tokens_per_file=20)
context_files = create_context_files(num_files=2, tokens_per_file=20)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
# Expected order:
# system, user1, assistant1, user2, assistant2,
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
assert len(result) == 11
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
assert result[9] == tool_response # After last user
assert result[10] == reminder # At the very end
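Taken together, the assertions in this class fix a single insertion order, with token budgeting layered on top. Below is a hypothetical outline of just the ordering contract; budget truncation and the JSON rendering of context files are elided, and `_ordering_sketch` is not the real `construct_message_history`.

```python
from onyx.configs.constants import MessageType


def _ordering_sketch(
    system_prompt,
    custom_agent_prompt,
    simple_chat_history,
    reminder_message,
    context_files_message,
):
    user_idxs = [
        i
        for i, m in enumerate(simple_chat_history)
        if m.message_type == MessageType.USER
    ]
    if simple_chat_history and not user_idxs:
        raise ValueError("No user message found in chat history")
    split = user_idxs[-1] if user_idxs else 0

    result = [system_prompt]
    result += simple_chat_history[:split]       # history before last user msg
    if custom_agent_prompt is not None:
        result.append(custom_agent_prompt)      # custom agent instructions
    if context_files_message is not None:
        result.append(context_files_message)    # injected context files
    result += simple_chat_history[split:]       # last user msg + tool turns
    if reminder_message is not None:
        result.append(reminder_message)         # always at the very end
    return result
```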
def test_project_files_json_format(self) -> None:
def test_context_files_json_format(self) -> None:
"""Test that project files are formatted correctly as JSON."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Hello", MessageType.USER, 5)
simple_chat_history = [user_msg]
project_files = create_project_files(num_files=2, tokens_per_file=50)
context_files = create_context_files(num_files=2, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=project_files,
context_files=context_files,
available_tokens=1000,
)
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=create_project_files(),
context_files=create_context_files(),
available_tokens=available_tokens,
token_counter=_simple_token_counter,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -44,7 +44,6 @@ def _build_provider_view(
id=1,
name="test-provider",
provider=provider,
default_model_name="test-model",
model_configurations=[
ModelConfigurationView(
name="test-model",
@@ -62,7 +61,6 @@ def _build_provider_view(
groups=[],
personas=[],
deployment_name=None,
default_vision_model=None,
)

View File

@@ -106,6 +106,9 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
patch(
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
) as mock_ctx,
patch(
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_tenant_ctx,
patch(
"onyx.server.metrics.postgres_connection_pool._connections_held"
) as mock_gauge,
@@ -114,12 +117,14 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
mock_labels = MagicMock()
mock_gauge.labels.return_value = mock_labels
mock_ctx.get.return_value = "/api/chat/send-message"
mock_tenant_ctx.get.return_value = "tenant_xyz"
listeners["checkout"](None, conn_record, None)
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
assert "_metrics_checkout_time" in conn_record.info
mock_gauge.labels.assert_called_with(
handler="/api/chat/send-message", engine="sync"
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
)
mock_labels.inc.assert_called_once()
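The patch targets above (`CURRENT_ENDPOINT_CONTEXTVAR`, `CURRENT_TENANT_ID_CONTEXTVAR`, `_connections_held`) suggest the shape of the checkout listener under test. A hedged sketch, treating those names as module globals of `onyx.server.metrics.postgres_connection_pool` and everything else as assumption:

```python
import time


def _on_checkout(dbapi_conn, conn_record, conn_proxy):
    # Contextvars are assumed to be populated per-request by middleware.
    endpoint = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
    # Stash label values on the record so the paired checkin listener can
    # decrement the gauge and observe hold time with the same labels.
    conn_record.info["_metrics_endpoint"] = endpoint
    conn_record.info["_metrics_tenant_id"] = tenant_id
    conn_record.info["_metrics_checkout_time"] = time.monotonic()
    _connections_held.labels(
        handler=endpoint, engine="sync", tenant_id=tenant_id
    ).inc()
```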
@@ -144,6 +149,7 @@ def test_checkin_event_observes_hold_duration() -> None:
conn_record = _make_conn_record()
conn_record.info["_metrics_endpoint"] = "/api/search"
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
with (
@@ -162,7 +168,9 @@ def test_checkin_event_observes_hold_duration() -> None:
listeners["checkin"](None, conn_record)
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
mock_gauge.labels.assert_called_with(
handler="/api/search", engine="sync", tenant_id="tenant_abc"
)
mock_labels.dec.assert_called_once()
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
mock_hist_labels.observe.assert_called_once()
@@ -172,11 +180,12 @@ def test_checkin_event_observes_hold_duration() -> None:
# conn_record.info should be cleaned up
assert "_metrics_endpoint" not in conn_record.info
assert "_metrics_tenant_id" not in conn_record.info
assert "_metrics_checkout_time" not in conn_record.info
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
"""Verify checkin gracefully handles missing endpoint info."""
"""Verify checkin gracefully handles missing endpoint and tenant info."""
engine = MagicMock()
engine.pool = MagicMock()
listeners: dict[str, Any] = {}
@@ -207,7 +216,9 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
listeners["checkin"](None, conn_record)
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
mock_gauge.labels.assert_called_with(
handler="unknown", engine="sync", tenant_id="unknown"
)
# --- setup_postgres_connection_pool_metrics tests ---

View File

@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
from prometheus_client import CollectorRegistry
from prometheus_client import Gauge
from onyx.server.metrics.per_tenant import per_tenant_request_callback
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.metrics.slow_requests import slow_request_callback
@@ -81,7 +82,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
inprogress_labels=True,
excluded_handlers=["/health", "/metrics", "/openapi.json"],
)
mock_instance.add.assert_called_once()
assert mock_instance.add.call_count == 3
mock_instance.instrument.assert_called_once_with(
app,
latency_lowr_buckets=(
@@ -100,6 +101,56 @@ def test_setup_attaches_instrumentator_to_app() -> None:
mock_instance.expose.assert_called_once_with(app)
def test_per_tenant_callback_increments_with_tenant_id() -> None:
"""Verify per-tenant callback reads tenant from contextvar and increments."""
with (
patch(
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_ctx,
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
):
mock_labels = MagicMock()
mock_counter.labels.return_value = mock_labels
mock_ctx.get.return_value = "tenant_abc"
info = _make_info(
duration=0.1, method="POST", handler="/api/chat", status="200"
)
per_tenant_request_callback(info)
mock_counter.labels.assert_called_once_with(
tenant_id="tenant_abc",
method="POST",
handler="/api/chat",
status="200",
)
mock_labels.inc.assert_called_once()
def test_per_tenant_callback_falls_back_to_unknown() -> None:
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
with (
patch(
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_ctx,
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
):
mock_labels = MagicMock()
mock_counter.labels.return_value = mock_labels
mock_ctx.get.return_value = None
info = _make_info(duration=0.1)
per_tenant_request_callback(info)
mock_counter.labels.assert_called_once_with(
tenant_id="unknown",
method="GET",
handler="/api/test",
status="200",
)
mock_labels.inc.assert_called_once()
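Between them, these two tests fully determine the callback: read the tenant from the contextvar, fall back to `"unknown"`, and increment one counter. A sketch consistent with that contract follows; the `info` attribute names (`method`, `modified_handler`, `modified_status`) are assumptions based on prometheus-fastapi-instrumentator's `Info` object.

```python
def per_tenant_request_callback(info) -> None:
    # CURRENT_TENANT_ID_CONTEXTVAR and _requests_by_tenant are module
    # globals of onyx.server.metrics.per_tenant (per the patch targets).
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
    _requests_by_tenant.labels(
        tenant_id=tenant_id,
        method=info.method,
        handler=info.modified_handler,
        status=info.modified_status,
    ).inc()
```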
def test_inprogress_gauge_increments_during_request() -> None:
"""Verify the in-progress gauge goes up while a request is in flight."""
registry = CollectorRegistry()

View File

@@ -163,3 +163,16 @@ Add clear comments:
- Any TODOs you add in the code must be accompanied by either the name/username
of the owner of that TODO, or an issue number for an issue referencing that
piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side
effects. Essentially every piece of meaningful logic should exist within some
function that has to be explicitly invoked. Acceptable exceptions to this may
include loading environment variables or setting up loggers.
- If you find yourself needing something like this, you may want that logic to
exist in a file dedicated to manual execution (one containing `if __name__ ==
"__main__":`) that should not be imported by anything else.
- Related to the above, do not conflate Python scripts you intend to run from
the command line (contains `if __name__ == "__main__":`) with modules you
intend to import from elsewhere. If for some unlikely reason they have to be
the same file, any logic specific to executing the file (including imports)
should be contained in the `if __name__ == "__main__":` block.
- Generally these executable files exist in `backend/scripts/` (see the sketch
below).
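A minimal sketch of the split these guidelines describe, with illustrative names: importing the module runs nothing, and all execution-only logic sits under the guard.

```python
# report.py -- importable logic only; no import-time side effects.
import sys


def count_lines(path: str) -> int:
    """Importable from anywhere; does nothing at import time."""
    with open(path) as f:
        return sum(1 for _ in f)


if __name__ == "__main__":
    # CLI-only logic (argument handling, printing) stays in this block.
    print(count_lines(sys.argv[1]))
```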

View File

@@ -468,7 +468,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -293,7 +293,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -298,7 +298,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -335,7 +335,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -232,7 +232,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -520,7 +520,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3
@@ -534,9 +534,10 @@ services:
required: false
# Below is needed for the `docker-out-of-docker` execution mode
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
user: root
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
# privileged: true

View File

@@ -10,7 +10,7 @@ requires-python = ">=3.11"
dependencies = [
"aioboto3==15.1.0",
"cohere==5.6.1",
"fastapi==0.128.0",
"fastapi==0.133.1",
"google-cloud-aiplatform==1.121.0",
"google-genai==1.52.0",
"litellm==1.81.6",
@@ -92,7 +92,7 @@ backend = [
"python-gitlab==5.6.0",
"python-pptx==0.6.23",
"pypandoc_binary==1.16.2",
"pypdf==6.6.2",
"pypdf==6.7.3",
"pytest-mock==3.12.0",
"pytest-playwright==0.7.0",
"python-docx==1.1.2",

View File

@@ -51,6 +51,7 @@ func NewRootCommand() *cobra.Command {
cmd.AddCommand(NewRunCICommand())
cmd.AddCommand(NewScreenshotDiffCommand())
cmd.AddCommand(NewWebCommand())
cmd.AddCommand(NewWhoisCommand())
return cmd
}

tools/ods/cmd/whois.go (new file, 159 lines)
View File

@@ -0,0 +1,159 @@
package cmd
import (
"fmt"
"os"
"regexp"
"strings"
"text/tabwriter"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/onyx-dot-app/onyx/tools/ods/internal/kube"
)
var safeIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`)
// NewWhoisCommand creates the whois command for looking up users/tenants.
func NewWhoisCommand() *cobra.Command {
var ctx string
cmd := &cobra.Command{
Use: "whois <email-fragment or tenant-id>",
Short: "Look up users and admins by email or tenant ID",
Long: `Look up tenant and user information from the data plane PostgreSQL database.
Requires: AWS SSO login, kubectl access to the EKS cluster.
Two modes (auto-detected):
Email fragment:
ods whois chris
→ Searches user_tenant_mapping for emails matching '%chris%'
Tenant ID:
ods whois tenant_abcd1234-...
→ Lists all admin emails in that tenant
Cluster connection is configured via KUBE_CTX_* environment variables.
Each variable is a space-separated tuple: "cluster region namespace"
export KUBE_CTX_DATA_PLANE="<cluster> <region> <namespace>"
export KUBE_CTX_CONTROL_PLANE="<cluster> <region> <namespace>"
etc...
Use -c to select which context (default: data_plane).`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
runWhois(args[0], ctx)
},
}
cmd.Flags().StringVarP(&ctx, "context", "c", "data_plane", "cluster context name (maps to KUBE_CTX_<NAME> env var)")
return cmd
}
func clusterFromEnv(name string) *kube.Cluster {
envKey := "KUBE_CTX_" + strings.ToUpper(name)
val := os.Getenv(envKey)
if val == "" {
log.Fatalf("Environment variable %s is not set.\n\nSet it as a space-separated tuple:\n export %s=\"<cluster> <region> <namespace>\"", envKey, envKey)
}
parts := strings.Fields(val)
if len(parts) != 3 {
log.Fatalf("%s must be a space-separated tuple of 3 values (cluster region namespace), got: %q", envKey, val)
}
return &kube.Cluster{Name: parts[0], Region: parts[1], Namespace: parts[2]}
}
// queryPod runs a SQL query via pginto on the given pod and returns cleaned output lines.
func queryPod(c *kube.Cluster, pod, sql string) []string {
raw, err := c.ExecOnPod(pod, "pginto", "-A", "-t", "-F", "\t", "-c", sql)
if err != nil {
log.Fatalf("Query failed: %v", err)
}
var lines []string
for _, line := range strings.Split(strings.TrimSpace(raw), "\n") {
line = strings.TrimSpace(line)
if line != "" && !strings.HasPrefix(line, "Connecting to ") {
lines = append(lines, line)
}
}
return lines
}
func runWhois(query string, ctx string) {
c := clusterFromEnv(ctx)
if err := c.EnsureContext(); err != nil {
log.Fatalf("Failed to ensure cluster context: %v", err)
}
log.Info("Finding api-server pod...")
pod, err := c.FindPod("api-server")
if err != nil {
log.Fatalf("Failed to find api-server pod: %v", err)
}
log.Debugf("Using pod: %s", pod)
if strings.HasPrefix(query, "tenant_") {
findAdminsByTenant(c, pod, query)
} else {
findByEmail(c, pod, query)
}
}
func findByEmail(c *kube.Cluster, pod, fragment string) {
fragment = strings.NewReplacer("'", "", `"`, "", `;`, "", `\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(fragment)
sql := fmt.Sprintf(
`SELECT email, tenant_id, active FROM public.user_tenant_mapping WHERE email LIKE '%%%s%%' ORDER BY email;`,
fragment,
)
log.Infof("Searching for emails matching '%%%s%%'...", fragment)
lines := queryPod(c, pod, sql)
if len(lines) == 0 {
fmt.Println("No results found.")
return
}
fmt.Println()
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "EMAIL\tTENANT ID\tACTIVE")
_, _ = fmt.Fprintln(w, "-----\t---------\t------")
for _, line := range lines {
_, _ = fmt.Fprintln(w, line)
}
_ = w.Flush()
}
func findAdminsByTenant(c *kube.Cluster, pod, tenantID string) {
if !safeIdentifier.MatchString(tenantID) {
log.Fatalf("Invalid tenant ID: %q (must be alphanumeric, hyphens, underscores only)", tenantID)
}
sql := fmt.Sprintf(
`SELECT email FROM "%s"."user" WHERE role = 'ADMIN' AND is_active = true AND email NOT LIKE 'api_key__%%' ORDER BY email;`,
tenantID,
)
log.Infof("Fetching admin emails for %s...", tenantID)
lines := queryPod(c, pod, sql)
if len(lines) == 0 {
fmt.Println("No admin users found for this tenant.")
return
}
fmt.Println()
fmt.Println("EMAIL")
fmt.Println("-----")
for _, line := range lines {
fmt.Println(line)
}
}

View File

@@ -0,0 +1,90 @@
package kube
import (
"bytes"
"fmt"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
)
// Cluster holds the connection info for a Kubernetes cluster.
type Cluster struct {
Name string
Region string
Namespace string
}
// EnsureContext makes sure the cluster exists in kubeconfig, calling
// aws eks update-kubeconfig only if the context is missing.
func (c *Cluster) EnsureContext() error {
// Check if context already exists in kubeconfig
cmd := exec.Command("kubectl", "config", "get-contexts", c.Name, "--no-headers")
if err := cmd.Run(); err == nil {
log.Debugf("Context %s already exists, skipping aws eks update-kubeconfig", c.Name)
return nil
}
log.Infof("Context %s not found, fetching kubeconfig from AWS...", c.Name)
cmd = exec.Command("aws", "eks", "update-kubeconfig", "--region", c.Region, "--name", c.Name, "--alias", c.Name)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("aws eks update-kubeconfig failed: %w\n%s", err, string(out))
}
return nil
}
// kubectlArgs returns common kubectl flags to target this cluster without mutating global context.
func (c *Cluster) kubectlArgs() []string {
return []string{"--context", c.Name, "--namespace", c.Namespace}
}
// FindPod returns the name of the first Running/Ready pod matching the given substring.
func (c *Cluster) FindPod(substring string) (string, error) {
args := append(c.kubectlArgs(), "get", "po",
"--field-selector", "status.phase=Running",
"--no-headers",
"-o", "custom-columns=NAME:.metadata.name,READY:.status.conditions[?(@.type=='Ready')].status",
)
cmd := exec.Command("kubectl", args...)
out, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("kubectl get po failed: %w\n%s", err, string(exitErr.Stderr))
}
return "", fmt.Errorf("kubectl get po failed: %w", err)
}
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
name, ready := fields[0], fields[1]
if strings.Contains(name, substring) && ready == "True" {
log.Debugf("Found pod: %s", name)
return name, nil
}
}
return "", fmt.Errorf("no ready pod found matching %q", substring)
}
// ExecOnPod runs a command on a pod and returns its stdout.
func (c *Cluster) ExecOnPod(pod string, command ...string) (string, error) {
args := append(c.kubectlArgs(), "exec", pod, "--")
args = append(args, command...)
log.Debugf("Running: kubectl %s", strings.Join(args, " "))
cmd := exec.Command("kubectl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("kubectl exec failed: %w\n%s", err, stderr.String())
}
return stdout.String(), nil
}

uv.lock (generated)
View File

@@ -1688,17 +1688,18 @@ wheels = [
[[package]]
name = "fastapi"
version = "0.128.0"
version = "0.133.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "annotated-doc" },
{ name = "pydantic" },
{ name = "starlette" },
{ name = "typing-extensions" },
{ name = "typing-inspection" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
sdist = { url = "https://files.pythonhosted.org/packages/22/6f/0eafed8349eea1fa462238b54a624c8b408cd1ba2795c8e64aa6c34f8ab7/fastapi-0.133.1.tar.gz", hash = "sha256:ed152a45912f102592976fde6cbce7dae1a8a1053da94202e51dd35d184fadd6", size = 378741, upload-time = "2026-02-25T18:18:17.398Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
{ url = "https://files.pythonhosted.org/packages/d2/c9/a175a7779f3599dfa4adfc97a6ce0e157237b3d7941538604aadaf97bfb6/fastapi-0.133.1-py3-none-any.whl", hash = "sha256:658f34ba334605b1617a65adf2ea6461901bdb9af3a3080d63ff791ecf7dc2e2", size = 109029, upload-time = "2026-02-25T18:18:18.578Z" },
]
[[package]]
@@ -4612,7 +4613,7 @@ requires-dist = [
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
{ name = "fastapi", specifier = "==0.128.0" },
{ name = "fastapi", specifier = "==0.133.1" },
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
@@ -4677,7 +4678,7 @@ requires-dist = [
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.6.2" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.3" },
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
@@ -5924,11 +5925,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.6.2"
version = "6.7.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
sdist = { url = "https://files.pythonhosted.org/packages/53/9b/63e767042fc852384dc71e5ff6f990ee4e1b165b1526cf3f9c23a4eebb47/pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2", size = 5303304, upload-time = "2026-02-24T17:23:11.42Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
{ url = "https://files.pythonhosted.org/packages/b0/90/3308a9b8b46c1424181fdf3f4580d2b423c5471425799e7fc62f92d183f4/pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a", size = 331263, upload-time = "2026-02-24T17:23:09.932Z" },
]
[[package]]
@@ -8079,14 +8080,14 @@ wheels = [
[[package]]
name = "werkzeug"
version = "3.1.5"
version = "3.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" }
sdist = { url = "https://files.pythonhosted.org/packages/61/f1/ee81806690a87dab5f5653c1f146c92bc066d7f4cebc603ef88eb9e13957/werkzeug-3.1.6.tar.gz", hash = "sha256:210c6bede5a420a913956b4791a7f4d6843a43b6fcee4dfa08a65e93007d0d25", size = 864736, upload-time = "2026-02-19T15:17:18.884Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" },
{ url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
]
[[package]]

View File

@@ -42,7 +42,7 @@ import SvgStar from "@opal/icons/star";
## Usage inside Content
Tag can be rendered as an accessory inside `Content`'s LabelLayout via the `tag` prop:
Tag can be rendered as an accessory inside `Content`'s ContentMd via the `tag` prop:
```tsx
import { Content } from "@opal/layouts";

View File

@@ -1,11 +1,7 @@
import "@opal/components/buttons/Button/styles.css";
import "@opal/components/tooltip.css";
import {
Interactive,
type InteractiveBaseProps,
type InteractiveContainerWidthVariant,
} from "@opal/core";
import type { SizeVariant } from "@opal/shared";
import { Interactive, type InteractiveBaseProps } from "@opal/core";
import type { SizeVariant, WidthVariant } from "@opal/shared";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
@@ -91,7 +87,7 @@ type ButtonProps = InteractiveBaseProps &
tooltip?: string;
/** Width preset. `"auto"` shrink-wraps, `"full"` stretches to parent width. */
width?: InteractiveContainerWidthVariant;
width?: WidthVariant;
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;

View File

@@ -0,0 +1,233 @@
import "@opal/core/hoverable/styles.css";
import React, { createContext, useContext, useState, useCallback } from "react";
import { cn } from "@opal/utils";
import type { WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
// Context-per-group registry
// ---------------------------------------------------------------------------
/**
* Lazily-created map of group names to React contexts.
*
* Each group gets its own `React.Context<boolean | null>` so that a
* `Hoverable.Item` only re-renders when its *own* group's hover state
* changes — not when any unrelated group changes.
*
* The default value is `null` (no provider found), which lets
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
* not hovered" and throw when `group` was explicitly specified.
*/
const contextMap = new Map<string, React.Context<boolean | null>>();
function getOrCreateContext(group: string): React.Context<boolean | null> {
let ctx = contextMap.get(group);
if (!ctx) {
ctx = createContext<boolean | null>(null);
ctx.displayName = `HoverableContext(${group})`;
contextMap.set(group, ctx);
}
return ctx;
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HoverableRootProps
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
children: React.ReactNode;
group: string;
}
type HoverableItemVariant = "opacity-on-hover";
interface HoverableItemProps
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
children: React.ReactNode;
group?: string;
variant?: HoverableItemVariant;
}
// ---------------------------------------------------------------------------
// HoverableRoot
// ---------------------------------------------------------------------------
/**
* Hover-tracking container for a named group.
*
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
* provides the hover state via a per-group React context.
*
* Nesting works because each `Hoverable.Root` creates a **new** context
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
* reads from the inner provider, not the outer `group="a"` provider.
*
* @example
* ```tsx
* <Hoverable.Root group="card">
* <Card>
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Card>
* </Hoverable.Root>
* ```
*/
function HoverableRoot({
group,
children,
onMouseEnter: consumerMouseEnter,
onMouseLeave: consumerMouseLeave,
...props
}: HoverableRootProps) {
const [hovered, setHovered] = useState(false);
const onMouseEnter = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(true);
consumerMouseEnter?.(e);
},
[consumerMouseEnter]
);
const onMouseLeave = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(false);
consumerMouseLeave?.(e);
},
[consumerMouseLeave]
);
const GroupContext = getOrCreateContext(group);
return (
<GroupContext.Provider value={hovered}>
<div {...props} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{children}
</div>
</GroupContext.Provider>
);
}
// ---------------------------------------------------------------------------
// HoverableItem
// ---------------------------------------------------------------------------
/**
* An element whose visibility is controlled by hover state.
*
* **Local mode** (`group` omitted): the item handles hover on its own
* element via CSS `:hover`. This is the core abstraction.
*
* **Group mode** (`group` provided): visibility is driven by a matching
* `Hoverable.Root` ancestor's hover state via React context. If no
* matching Root is found, an error is thrown.
*
* Uses data-attributes for variant styling (see `styles.css`).
*
* @example
* ```tsx
* // Local mode — hover on the item itself
* <Hoverable.Item variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
*
* // Group mode — hover on the Root reveals the item
* <Hoverable.Root group="card">
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Hoverable.Root>
* ```
*
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
*/
function HoverableItem({
group,
variant = "opacity-on-hover",
children,
...props
}: HoverableItemProps) {
const contextValue = useContext(
group ? getOrCreateContext(group) : NOOP_CONTEXT
);
if (group && contextValue === null) {
throw new Error(
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
);
}
const isLocal = group === undefined;
return (
<div
{...props}
className={cn("hoverable-item")}
data-hoverable-variant={variant}
data-hoverable-active={
isLocal ? undefined : contextValue ? "true" : undefined
}
data-hoverable-local={isLocal ? "true" : undefined}
>
{children}
</div>
);
}
/** Stable context used when no group is specified (local mode). */
const NOOP_CONTEXT = createContext<boolean | null>(null);
// ---------------------------------------------------------------------------
// Compound export
// ---------------------------------------------------------------------------
/**
* Hoverable compound component for hover-to-reveal patterns.
*
* Provides two sub-components:
*
* - `Hoverable.Root` — A container that tracks hover state for a named group
* and provides it via React context.
*
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
* applies local CSS `:hover` for the variant effect. When `group` is
* specified, it reads hover state from the nearest matching
* `Hoverable.Root` — and throws if no matching Root is found.
*
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
* so each group's items only respond to their own root's hover.
*
* @example
* ```tsx
* import { Hoverable } from "@opal/core";
*
* // Group mode — hovering the card reveals the trash icon
* <Hoverable.Root group="card">
* <Card>
* <span>Card content</span>
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Card>
* </Hoverable.Root>
*
* // Local mode — hovering the item itself reveals it
* <Hoverable.Item variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* ```
*/
const Hoverable = {
Root: HoverableRoot,
Item: HoverableItem,
};
export {
Hoverable,
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
};

View File

@@ -0,0 +1,18 @@
/* Hoverable — item transitions */
.hoverable-item {
transition: opacity 150ms ease-in-out;
}
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
opacity: 0;
}
/* Group mode — Root controls visibility via React context */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
opacity: 1;
}
/* Local mode — item handles its own :hover */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
opacity: 1;
}

View File

@@ -1,9 +1,16 @@
/* Hoverable */
export {
Hoverable,
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
} from "@opal/core/hoverable/components";
/* Interactive */
export {
Interactive,
type InteractiveBaseProps,
type InteractiveBaseVariantProps,
type InteractiveContainerProps,
type InteractiveContainerWidthVariant,
type InteractiveContainerRoundingVariant,
} from "@opal/core/interactive/components";

View File

@@ -3,7 +3,12 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import type { WithoutStyles } from "@opal/types";
import { sizeVariants, type SizeVariant } from "@opal/shared";
import {
sizeVariants,
type SizeVariant,
widthVariants,
type WidthVariant,
} from "@opal/shared";
// ---------------------------------------------------------------------------
// Types
@@ -39,18 +44,6 @@ type InteractiveBaseVariantProps =
selected?: never;
};
/**
* Width presets for `Interactive.Container`.
*
* - `"auto"` — Shrink-wraps to content width (default)
* - `"full"` — Stretches to fill the parent's width (`w-full`)
*/
type InteractiveContainerWidthVariant = "auto" | "full";
const interactiveContainerWidthVariants = {
auto: "w-auto",
full: "w-full",
} as const;
/**
* Border-radius presets for `Interactive.Container`.
*
@@ -345,7 +338,7 @@ interface InteractiveContainerProps
*
* @default "auto"
*/
widthVariant?: InteractiveContainerWidthVariant;
widthVariant?: WidthVariant;
}
/**
@@ -413,7 +406,7 @@ function InteractiveContainer({
height,
minWidth,
padding,
interactiveContainerWidthVariants[widthVariant],
widthVariants[widthVariant],
slotClassName
),
"data-border": border ? ("true" as const) : undefined,
@@ -490,6 +483,5 @@ export {
type InteractiveBaseVariantProps,
type InteractiveBaseSelectVariantProps,
type InteractiveContainerProps,
type InteractiveContainerWidthVariant,
type InteractiveContainerRoundingVariant,
};

Some files were not shown because too many files have changed in this diff.