mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 14:15:44 +00:00
Compare commits
51 Commits
v3.0.0-clo
...
content-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f5d7e271a | ||
|
|
bb6e20614d | ||
|
|
bd9319e592 | ||
|
|
db5955d6f2 | ||
|
|
5e447440ea | ||
|
|
78c6ca39b8 | ||
|
|
71a7cf09b3 | ||
|
|
91d30a0156 | ||
|
|
7b30752767 | ||
|
|
4450ecf07c | ||
|
|
0e6b766996 | ||
|
|
12c8cd338b | ||
|
|
ad5688bf65 | ||
|
|
d2deefd1f1 | ||
|
|
18b90d405d | ||
|
|
8394e8837b | ||
|
|
f06df891c4 | ||
|
|
d6d5e72c18 | ||
|
|
449f5d62f9 | ||
|
|
4d256c5666 | ||
|
|
2e53496f46 | ||
|
|
63a206706a | ||
|
|
28427b3e5f | ||
|
|
3cafcd8a5e | ||
|
|
f2c50b7bb5 | ||
|
|
6b28c6bbfc | ||
|
|
226e801665 | ||
|
|
be13aa1310 | ||
|
|
45d38c4906 | ||
|
|
8aab518532 | ||
|
|
da6ce10e86 | ||
|
|
aaf8253520 | ||
|
|
7c7f81b164 | ||
|
|
2d4a3c72e9 | ||
|
|
7c51712018 | ||
|
|
aa5614695d | ||
|
|
8d7255d3c4 | ||
|
|
d403498f48 | ||
|
|
9ef3095c17 | ||
|
|
a39e93a0cb | ||
|
|
46d73cdfee | ||
|
|
1e04ce78e0 | ||
|
|
f9b81c1725 | ||
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 | ||
|
|
5c5a6f943b |
@@ -9,7 +9,8 @@ inputs:
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
required: false
|
||||
default: ""
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
@@ -17,6 +18,14 @@ inputs:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
api-version:
|
||||
description: "Optional NIGHTLY_LLM_API_VERSION"
|
||||
required: false
|
||||
default: ""
|
||||
deployment-name:
|
||||
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
@@ -59,6 +68,7 @@ runs:
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
@@ -82,6 +92,8 @@ runs:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
|
||||
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
@@ -91,11 +103,6 @@ runs:
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -110,10 +117,13 @@ runs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
|
||||
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
|
||||
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -114,8 +114,10 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -603,7 +603,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
@@ -89,6 +89,10 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
@@ -3,33 +3,66 @@ name: Reusable Nightly LLM Provider Chat Tests
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_models:
|
||||
description: "Comma-separated models for azure"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
ollama_models:
|
||||
description: "Comma-separated models for ollama_chat"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
openrouter_models:
|
||||
description: "Comma-separated models for openrouter"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_api_base:
|
||||
description: "API base for azure provider"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
@@ -38,29 +71,8 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -90,7 +102,6 @@ jobs:
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -119,7 +130,6 @@ jobs:
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -149,11 +159,75 @@ jobs:
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -167,12 +241,14 @@ jobs:
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
@@ -194,7 +270,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add python tool on default
|
||||
|
||||
Revision ID: 57122d037335
|
||||
Revises: c0c937d5c9e5
|
||||
Create Date: 2026-02-27 10:10:40.124925
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "57122d037335"
|
||||
down_revision = "c0c937d5c9e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up the PythonTool id
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
tool_id = result[0]
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE persona_id = 0 AND tool_id = :tool_id
|
||||
"""
|
||||
),
|
||||
{"tool_id": result[0]},
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""llm provider deprecate fields
|
||||
|
||||
Revision ID: c0c937d5c9e5
|
||||
Revises: 8ffcc2bcfc11
|
||||
Create Date: 2026-02-25 17:35:46.125102
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0c937d5c9e5"
|
||||
down_revision = "8ffcc2bcfc11"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
|
||||
op.drop_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
type_="unique",
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore unique constraint on is_default_provider
|
||||
op.create_unique_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
["is_default_provider"],
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -322,6 +322,7 @@ def list_users(
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -365,6 +366,7 @@ def get_user(
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
@@ -721,6 +723,7 @@ def list_groups(
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -757,6 +760,7 @@ def get_group(
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
|
||||
@@ -20,6 +20,7 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
@@ -117,15 +118,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
if llm_upsert_requests:
|
||||
logger.notice("Seeding LLMs")
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
if not llm_upsert_requests:
|
||||
return
|
||||
|
||||
logger.notice("Seeding LLMs")
|
||||
for request in llm_upsert_requests:
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
)
|
||||
if not default_provider:
|
||||
return
|
||||
|
||||
visible_configs = [
|
||||
mc for mc in default_provider.model_configurations if mc.is_visible
|
||||
]
|
||||
default_config = (
|
||||
visible_configs[0]
|
||||
if visible_configs
|
||||
else default_provider.model_configurations[0]
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=default_provider.id,
|
||||
model_name=default_config.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
settings.ee_features_enabled = False
|
||||
elif metadata.used_seats > metadata.seats:
|
||||
# License is valid but seat limit exceeded
|
||||
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
|
||||
settings.seat_count = metadata.seats
|
||||
settings.used_seats = metadata.used_seats
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=request.name, db_session=db_session
|
||||
)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -163,13 +162,11 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,8 +221,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -242,29 +241,28 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if context_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -275,7 +273,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -289,7 +287,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -445,17 +443,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if context_files and context_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -466,14 +464,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 5. Add last user message (with context images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -547,11 +545,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -559,7 +557,7 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -570,10 +568,10 @@ def _create_project_files_message(
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from project_files
|
||||
# Use pre-calculated token count from context_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=project_files.total_token_count,
|
||||
token_count=context_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -584,7 +582,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
context_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -627,9 +625,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -647,7 +645,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -788,7 +786,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
for uf in user_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -879,46 +864,54 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -23,7 +23,6 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
|
||||
@@ -872,6 +871,56 @@ class SharepointConnector(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
|
||||
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
|
||||
tenant domain is the first label of the hostname.
|
||||
"""
|
||||
for site_url in self.sites:
|
||||
try:
|
||||
hostname = urlsplit(site_url.strip()).hostname
|
||||
except ValueError:
|
||||
continue
|
||||
if not hostname:
|
||||
continue
|
||||
tenant = hostname.split(".")[0]
|
||||
if tenant:
|
||||
return tenant
|
||||
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
|
||||
return None
|
||||
|
||||
def _resolve_tenant_domain_from_root_site(self) -> str:
|
||||
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
|
||||
Sites.Read.All (a permission the connector already needs)."""
|
||||
root_site = self.graph_client.sites.root.get().execute_query()
|
||||
hostname = root_site.site_collection.hostname
|
||||
if not hostname:
|
||||
raise ConnectorValidationError(
|
||||
"Could not determine tenant domain from root site"
|
||||
)
|
||||
tenant_domain = hostname.split(".")[0]
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from root site hostname '%s'",
|
||||
tenant_domain,
|
||||
hostname,
|
||||
)
|
||||
return tenant_domain
|
||||
|
||||
def _resolve_tenant_domain(self) -> str:
|
||||
"""Determine the tenant domain, preferring site URLs over a Graph API
|
||||
call to avoid needing extra permissions."""
|
||||
from_sites = self._extract_tenant_domain_from_sites()
|
||||
if from_sites:
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from site URLs",
|
||||
from_sites,
|
||||
)
|
||||
return from_sites
|
||||
|
||||
logger.info("No site URLs available; resolving tenant domain from root site")
|
||||
return self._resolve_tenant_domain_from_root_site()
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
@@ -1589,6 +1638,11 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
if not sp_client_id:
|
||||
raise ConnectorValidationError("Client ID is required")
|
||||
if not sp_directory_id:
|
||||
raise ConnectorValidationError("Directory (tenant) ID is required")
|
||||
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
@@ -1641,21 +1695,7 @@ class SharepointConnector(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
raise ConnectorValidationError("No organization found")
|
||||
|
||||
tenant_info: Organization = org[
|
||||
0
|
||||
] # Access first item directly from collection
|
||||
if not tenant_info.verified_domains:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
|
||||
sp_tenant_domain = tenant_info.verified_domains[0].name
|
||||
if not sp_tenant_domain:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
# remove the .onmicrosoft.com part
|
||||
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import ssl_context
|
||||
from onyx.db.engine.sql_engine import ASYNC_DB_API
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
connect_args["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
cparams["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any
|
||||
@@ -48,11 +49,9 @@ def provide_iam_token(
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -213,11 +213,29 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
existing_llm_provider: LLMProviderModel | None = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_llm_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} not found"
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
if existing_llm_provider.name != llm_provider_upsert_request.name:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
|
||||
)
|
||||
else:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with name '{llm_provider_upsert_request.name}'"
|
||||
" already exists"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -238,11 +256,7 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -306,15 +320,6 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -488,6 +493,22 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -604,22 +625,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
provider.default_model_name,
|
||||
model_name,
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -805,12 +817,6 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -2822,13 +2822,17 @@ class LLMProvider(Base):
|
||||
custom_config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# Deprecated: use LLMModelFlow with CHAT flow type instead
|
||||
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
@@ -2879,6 +2883,7 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
|
||||
@@ -256,9 +256,6 @@ def create_update_persona(
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
@@ -335,6 +332,7 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -344,9 +342,7 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -360,6 +356,15 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -965,6 +970,8 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
@@ -1161,9 +1168,6 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
|
||||
@@ -57,12 +58,19 @@ def fetch_user_project_ids_for_user_files(
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch user project ids for specified user files"""
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [project.id for project in user_file.projects]
|
||||
for user_file in results
|
||||
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
|
||||
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
|
||||
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
|
||||
)
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
user_file_id_to_project_ids: dict[str, list[int]] = {
|
||||
user_file_id: [] for user_file_id in user_file_ids
|
||||
}
|
||||
for user_file_id, project_id in rows:
|
||||
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
|
||||
|
||||
return user_file_id_to_project_ids
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_final_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -49,8 +50,11 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -118,8 +122,11 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -107,9 +113,8 @@ class OpenSearchClient:
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -125,6 +130,142 @@ class OpenSearchClient:
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -192,6 +333,38 @@ class OpenSearchClient:
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -610,43 +783,6 @@ class OpenSearchClient:
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -807,48 +943,6 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings.
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
return self._client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._os_client.update_document(
|
||||
self._client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._os_client.search(
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -405,6 +405,7 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -415,14 +416,22 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -97,7 +97,6 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -233,12 +237,9 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -268,7 +269,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -303,7 +304,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -328,7 +329,15 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -344,18 +353,44 @@ def put_llm_provider(
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -393,22 +428,6 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -438,8 +457,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -469,28 +488,29 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
@admin_router.post("/default")
|
||||
def set_provider_as_default(
|
||||
provider_id: int,
|
||||
default_model_request: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
@admin_router.post("/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
default_model: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
provider_id=default_model.provider_id,
|
||||
vision_model=default_model.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -516,7 +536,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -545,7 +565,13 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
return vision_provider_response
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -592,7 +618,15 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -635,7 +669,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -682,7 +716,51 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return llm_provider_list
|
||||
# Get the default model and vision model for the persona
|
||||
# TODO: Port persona's over to use ID
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text = DefaultModel.from_model_config(default_text_model)
|
||||
default_vision = DefaultModel.from_model_config(default_vision_model)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider and can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
# There is no logic that requires sending the providers default vision model name
|
||||
# We only send for the one that is actually the default
|
||||
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
|
||||
"""Find the default conversation model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns empty string.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
|
||||
return model_config.name
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
|
||||
"""Find the default vision model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns None.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
|
||||
return model_config.name
|
||||
return None
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
id: int | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,22 +72,12 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -128,18 +91,17 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -155,8 +117,6 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -178,14 +138,6 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -198,10 +150,6 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -421,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
27
backend/onyx/server/metrics/per_tenant.py
Normal file
27
backend/onyx/server/metrics/per_tenant.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Per-tenant request counter metric.
|
||||
|
||||
Increments a counter on every request, labelled by tenant, so Grafana can
|
||||
answer "which tenant is generating the most traffic?"
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_requests_by_tenant = Counter(
|
||||
"onyx_api_requests_by_tenant_total",
|
||||
"Total API requests by tenant",
|
||||
["tenant_id", "method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def per_tenant_request_callback(info: Info) -> None:
|
||||
"""Increment per-tenant request counter for every request."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
_requests_by_tenant.labels(
|
||||
tenant_id=tenant_id,
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
@@ -32,6 +32,7 @@ from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -72,7 +73,7 @@ _checkout_timeout_total = Counter(
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine"],
|
||||
["handler", "engine", "tenant_id"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
@@ -163,10 +164,14 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_tenant_id"] = tenant_id
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
@@ -174,9 +179,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
@@ -199,9 +207,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
|
||||
@@ -11,9 +11,11 @@ SQLAlchemy connection pool metrics are registered separately via
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -59,6 +61,15 @@ def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
# Explicitly create the default metrics (http_requests_total,
|
||||
# http_request_duration_seconds, etc.) and add them first. The library
|
||||
# skips creating defaults when ANY custom instrumentations are registered
|
||||
# via .add(), so we must include them ourselves.
|
||||
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
|
||||
if default_callback:
|
||||
instrumentator.add(default_callback)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
instrumentator.add(per_tenant_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
|
||||
@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GRACE_PERIOD = "grace_period"
|
||||
GATED_ACCESS = "gated_access"
|
||||
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -82,6 +83,10 @@ class Settings(BaseModel):
|
||||
# Default Assistant settings
|
||||
disable_default_assistant: bool | None = False
|
||||
|
||||
# Seat usage - populated by license enforcement when seat limit is exceeded
|
||||
seat_count: int | None = None
|
||||
used_seats: int | None = None
|
||||
|
||||
# OpenSearch migration
|
||||
opensearch_indexing_enabled: bool = False
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -24,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -32,6 +34,9 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -250,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
provider_name = "DevEnvPresetOpenAI"
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
name="DevEnvPresetOpenAI",
|
||||
id=existing.id if existing else None,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -269,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
@@ -311,7 +322,14 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ def generate_intermediate_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ def run_research_agent_call(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=msg_history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# fastapi-users
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.6.2
|
||||
pypdf==6.7.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
@@ -1216,7 +1217,7 @@ websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
# google-genai
|
||||
werkzeug==3.1.5
|
||||
werkzeug==3.1.6
|
||||
# via sendgrid
|
||||
wrapt==1.17.3
|
||||
# via
|
||||
|
||||
@@ -125,7 +125,7 @@ executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -90,7 +90,7 @@ docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -108,7 +108,7 @@ durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# sentry-sdk
|
||||
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
@@ -521,3 +522,46 @@ def test_sharepoint_connector_hierarchy_nodes(
|
||||
f"Document {doc.semantic_identifier} should have "
|
||||
"parent_hierarchy_raw_node_id set"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_cert_credentials() -> dict[str, str]:
|
||||
return {
|
||||
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
|
||||
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
|
||||
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
|
||||
"sp_certificate_password": os.environ[
|
||||
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
|
||||
],
|
||||
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
|
||||
}
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_site_urls(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain from site URLs
|
||||
without calling the /organization endpoint."""
|
||||
site_url = os.environ["SHAREPOINT_SITE"]
|
||||
connector = SharepointConnector(sites=[site_url])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
# The tenant domain should match the first label of the site URL hostname
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
|
||||
assert connector.sp_tenant_domain == expected
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_root_site(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain via the root
|
||||
site endpoint when no site URLs are configured."""
|
||||
connector = SharepointConnector(sites=[])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
External dependency unit tests for persona file sync.
|
||||
|
||||
Validates that:
|
||||
|
||||
1. The check_for_user_file_project_sync beat task picks up UserFiles with
|
||||
needs_persona_sync=True (not just needs_project_sync).
|
||||
|
||||
2. The process_single_user_file_project_sync worker task reads persona
|
||||
associations from the DB, passes persona_ids to the document index via
|
||||
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
|
||||
|
||||
3. upsert_persona correctly marks affected UserFiles with
|
||||
needs_persona_sync=True when file associations change.
|
||||
|
||||
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
|
||||
since we only need to verify the arguments passed to update_single.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_completed_user_file(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
needs_persona_sync: bool = False,
|
||||
needs_project_sync: bool = False,
|
||||
) -> UserFile:
|
||||
"""Insert a UserFile in COMPLETED status."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
needs_project_sync=needs_project_sync,
|
||||
chunk_count=5,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_test_persona(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user_files: list[UserFile] | None = None,
|
||||
) -> Persona:
|
||||
"""Create a minimal Persona via direct model insert."""
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_files=user_files or [],
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _link_file_to_persona(
|
||||
db_session: Session, persona: Persona, user_file: UserFile
|
||||
) -> None:
|
||||
"""Create the join table row between a persona and a user file."""
|
||||
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
_PATCH_QUEUE_DEPTH = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
".get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on a bound Celery task."""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: check_for_user_file_project_sync picks up persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSweepIncludesPersonaSync:
|
||||
"""The beat task must pick up files needing persona sync, not just project sync."""
|
||||
|
||||
def test_persona_sync_flag_enqueues_task(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
|
||||
user = create_test_user(db_session, "persona_sweep")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) in enqueued_ids
|
||||
|
||||
def test_neither_flag_does_not_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with both flags False is not enqueued."""
|
||||
user = create_test_user(db_session, "no_sync")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) not in enqueued_ids
|
||||
|
||||
def test_both_flags_enqueues_once(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with BOTH flags True is enqueued exactly once."""
|
||||
user = create_test_user(db_session, "both_flags")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
matching_calls = [
|
||||
call
|
||||
for call in mock_app.send_task.call_args_list
|
||||
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
|
||||
]
|
||||
assert len(matching_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: process_single_user_file_project_sync passes persona_ids to index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_GET_SETTINGS = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
|
||||
)
|
||||
_PATCH_GET_INDICES = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
|
||||
)
|
||||
_PATCH_HTTPX_INIT = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
|
||||
)
|
||||
_PATCH_DISABLE_VDB = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
|
||||
)
|
||||
|
||||
|
||||
class TestSyncTaskWritesPersonaIds:
|
||||
"""The sync task reads persona associations and sends them to the index."""
|
||||
|
||||
def test_passes_persona_ids_to_update_single(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After linking a file to a persona, sync sends the persona ID."""
|
||||
user = create_test_user(db_session, "sync_persona")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
mock_doc_index.update_single.assert_called_once()
|
||||
call_args = mock_doc_index.update_single.call_args
|
||||
user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert call_args.args[0] == str(uf.id)
|
||||
|
||||
def test_clears_persona_sync_flag(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a successful sync the needs_persona_sync flag is cleared."""
|
||||
user = create_test_user(db_session, "sync_clear")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with patch(_PATCH_DISABLE_VDB, True):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
def test_passes_both_project_and_persona_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to both a project and a persona gets both IDs."""
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserProject
|
||||
|
||||
user = create_test_user(db_session, "sync_both")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
project = UserProject(user_id=user.id, name="test-project", instructions="")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
|
||||
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert user_fields.user_projects is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert project.id in user_fields.user_projects
|
||||
|
||||
# Both flags should be cleared
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.needs_project_sync is False
|
||||
|
||||
def test_deleted_persona_excluded_from_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
|
||||
user = create_test_user(db_session, "sync_deleted")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id not in user_fields.personas
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: upsert_persona marks files for persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpsertPersonaMarksSyncFlag:
|
||||
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
|
||||
|
||||
def test_creating_persona_with_files_marks_sync(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "upsert_create")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
|
||||
def test_updating_persona_files_marks_both_old_and_new(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When file associations change, both the removed and added files are flagged."""
|
||||
user = create_test_user(db_session, "upsert_update")
|
||||
uf_old = _create_completed_user_file(db_session, user)
|
||||
uf_new = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf_old.id],
|
||||
)
|
||||
|
||||
# Clear the flag from creation so we can observe the update
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[uf_new.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf_old)
|
||||
db_session.refresh(uf_new)
|
||||
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
|
||||
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
|
||||
|
||||
def test_removing_all_files_marks_old_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Removing all files from a persona flags the previously associated files."""
|
||||
user = create_test_user(db_session, "upsert_remove")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
|
||||
Mocks the LLM tokenizer and file store since they are not relevant here.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_user_file(db_session: Session, user: User) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
chunk_count=1,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _create_project(db_session: Session, user: User) -> UserProject:
|
||||
project = UserProject(
|
||||
user_id=user.id,
|
||||
name=f"project-{uuid4().hex[:8]}",
|
||||
instructions="",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
|
||||
doc = Document(
|
||||
id=str(user_file.id),
|
||||
source=DocumentSource.USER_FILE,
|
||||
semantic_identifier=user_file.name,
|
||||
sections=[TextSection(text="test chunk content", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
source_document=doc,
|
||||
chunk_id=0,
|
||||
blurb="test chunk",
|
||||
content="test chunk content",
|
||||
source_links={0: ""},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.0] * 768,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_persona_gets_persona_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_persona")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_project_gets_project_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_project")
|
||||
uf = _create_user_file(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_both_gets_both_ids(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_both")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_with_no_associations_gets_empty_lists(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_empty")
|
||||
uf = _create_user_file(db_session, user)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_multiple_personas_all_appear(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to multiple personas should have all their IDs."""
|
||||
user = create_test_user(db_session, "adapter_multi")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona_a = _create_persona(db_session, user)
|
||||
persona_b = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
|
||||
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -44,15 +45,14 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> None:
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
upsert_llm_provider(
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
upsert_llm_provider(
|
||||
provider = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name,
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=model_name,
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -49,7 +49,6 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert provider.custom_config == custom_config, (
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {provider.custom_config}"
|
||||
f"but got {db_provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
view = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
provider=provider_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -15,9 +15,11 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
|
||||
)
|
||||
|
||||
# Verify initial state: all models are visible
|
||||
provider = fetch_existing_llm_provider(
|
||||
initial_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
assert initial_provider is not None
|
||||
assert initial_provider.is_auto_mode is False
|
||||
|
||||
for mc in provider.model_configurations:
|
||||
for mc in initial_provider.model_configurations:
|
||||
assert (
|
||||
mc.is_visible is True
|
||||
), f"Initial model '{mc.name}' should be visible"
|
||||
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=initial_provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
provider_view = fetch_llm_provider_view(
|
||||
provider_name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
assert provider_view is not None
|
||||
assert provider_view.is_auto_mode is True
|
||||
|
||||
# Build a map of model name -> visibility
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
mc.name: mc.is_visible for mc in provider_view.model_configurations
|
||||
}
|
||||
|
||||
# Models in auto mode config should be visible
|
||||
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Disabling auto mode should preserve the current model list and
|
||||
prevent future syncs from altering visibility.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode — models synced from config.
|
||||
2. Update provider to manual mode (is_auto_mode=False).
|
||||
3. Verify all models remain with unchanged visibility.
|
||||
4. Call sync_auto_mode_models with a *different* config.
|
||||
5. Verify fetch_auto_mode_providers excludes this provider, so the
|
||||
periodic task would never call sync on it.
|
||||
"""
|
||||
initial_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create in auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=initial_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility_before = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
|
||||
|
||||
# Step 2: Switch to manual mode
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=False,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Models unchanged
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
visibility_after = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_after == visibility_before
|
||||
|
||||
# Step 4-5: Provider excluded from auto mode queries
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
auto_provider_ids = {p.id for p in auto_providers}
|
||||
assert provider.id not in auto_provider_ids
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_resync_adds_new_and_hides_removed_models(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the GitHub config changes between syncs, a subsequent sync
|
||||
should add newly listed models and hide models that were removed.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
|
||||
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
|
||||
gpt-4-turbo added).
|
||||
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
|
||||
and visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create with config v1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Re-sync with config v2
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 3: Verify
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# gpt-4o: still in config -> visible
|
||||
assert visibility["gpt-4o"] is True
|
||||
# gpt-4o-mini: removed from config -> hidden (not deleted)
|
||||
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
|
||||
assert visibility["gpt-4o-mini"] is False
|
||||
# gpt-4-turbo: newly added -> visible
|
||||
assert visibility["gpt-4-turbo"] is True
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_is_idempotent(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Running sync twice with the same config should produce zero
|
||||
changes on the second call."""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
# First explicit sync (may report changes if creation already synced)
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Snapshot state after first sync
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
snapshot = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Second sync — should be a no-op
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert (
|
||||
changes == 0
|
||||
), f"Expected 0 changes on idempotent re-sync, got {changes}"
|
||||
|
||||
# State should be identical
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
current = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
assert current == snapshot
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_default_model_hidden_when_removed_from_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the current default model is removed from the config, sync
|
||||
should hide it. The default model flow row should still exist (it
|
||||
points at the ModelConfiguration), but the model is no longer visible.
|
||||
|
||||
Steps:
|
||||
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
|
||||
2. Set gpt-4o as the global default.
|
||||
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
|
||||
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
|
||||
fetch_default_llm_model still returns a result (the flow row persists).
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=[],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Set gpt-4o as global default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Step 3: Re-sync with config v2 (gpt-4o removed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 4: Verify visibility
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -64,7 +64,6 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""External dependency unit tests for OpenSearchClient.
|
||||
"""External dependency unit tests for OpenSearchIndexClient.
|
||||
|
||||
These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for testing with automatic cleanup."""
|
||||
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
client = OpenSearchClient(index_name=test_index_name)
|
||||
client = OpenSearchIndexClient(index_name=test_index_name)
|
||||
|
||||
yield client # Test runs here.
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_client(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
|
||||
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
|
||||
"""Creates a search pipeline for testing with automatic cleanup."""
|
||||
test_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
|
||||
|
||||
|
||||
class TestOpenSearchClient:
|
||||
"""Tests for OpenSearchClient."""
|
||||
"""Tests for OpenSearchIndexClient."""
|
||||
|
||||
def test_create_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index with a real schema."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
|
||||
# Verify index exists.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting an existing index returns True."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
|
||||
assert result is True
|
||||
assert test_client.validate_index(expected_mappings=mappings) is False
|
||||
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting a nonexistent index returns False."""
|
||||
# Under test.
|
||||
# Don't create index, just try to delete.
|
||||
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_index_exists(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests checking if an index exists."""
|
||||
# Precondition.
|
||||
# Index should not exist before creation.
|
||||
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
|
||||
# Index should exist after creation.
|
||||
assert test_client.index_exists() is True
|
||||
|
||||
def test_validate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests validating an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
|
||||
# Should return True after creation.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests put_mapping with same schema is idempotent."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Applying the same mappings again should succeed.
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Index should still be valid.
|
||||
assert test_client.validate_index(expected_mappings=mappings)
|
||||
|
||||
def test_put_mapping_adds_new_field(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping successfully adds new fields to existing index."""
|
||||
# Precondition.
|
||||
# Create index with minimal schema (just required fields).
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Add a new field using put_mapping.
|
||||
updated_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
# New field
|
||||
"new_test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
# Should not raise.
|
||||
test_client.put_mapping(updated_mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Validate the new schema includes the new field.
|
||||
assert test_client.validate_index(expected_mappings=updated_mappings)
|
||||
|
||||
def test_put_mapping_fails_on_type_change(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping fails when trying to change existing field type."""
|
||||
# Precondition.
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Try to change test_field type from keyword to text.
|
||||
conflicting_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "text"}, # Changed from keyword to text
|
||||
},
|
||||
}
|
||||
# Should raise because field type cannot be changed.
|
||||
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
|
||||
test_client.put_mapping(conflicting_mappings)
|
||||
|
||||
def test_put_mapping_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping on non-existent index raises an error."""
|
||||
# Precondition.
|
||||
# Index does not exist yet.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index twice raises an error."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchClient) -> None:
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests creating and deleting a search pipeline."""
|
||||
# Under test and postcondition.
|
||||
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_bulk_index_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests bulk indexing documents."""
|
||||
# Precondition.
|
||||
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
|
||||
test_client.index_document(document=doc, tenant_state=tenant_state)
|
||||
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
|
||||
assert len(keep_ids) == 1
|
||||
|
||||
def test_update_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a document's properties."""
|
||||
# Precondition.
|
||||
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
|
||||
assert updated_doc.public == doc.public
|
||||
|
||||
def test_update_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_search_empty_index(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
|
||||
previous_score = current_score
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_random_search(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests the random search query works."""
|
||||
# Precondition.
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> None:
|
||||
opensearch_client.refresh_index()
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
@@ -283,10 +284,10 @@ def vespa_document_index(
|
||||
def opensearch_client(
|
||||
db_session: Session,
|
||||
full_deployment_setup: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for the test tenant."""
|
||||
active = get_active_search_settings(db_session)
|
||||
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
|
||||
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
def test_documents(
|
||||
db_session: Session,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
patch_get_vespa_chunks_page_size: int,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""
|
||||
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
|
||||
@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
|
||||
name=provider_name,
|
||||
provider="openai",
|
||||
api_key="test-api-key",
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
|
||||
@@ -434,7 +434,6 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -448,7 +447,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.oauth_config import create_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
|
||||
# Should use & instead of ? since URL already has query params
|
||||
assert "foo=bar&" in url or "?foo=bar" in url
|
||||
assert "client_id=custom_client_id" in url
|
||||
|
||||
|
||||
class TestUnwrapSensitiveStr:
|
||||
"""Tests for _unwrap_sensitive_str static method"""
|
||||
|
||||
def test_unwrap_sensitive_str(self) -> None:
|
||||
"""Test that both SensitiveValue and plain str inputs are handled"""
|
||||
# SensitiveValue input
|
||||
sensitive = SensitiveValue[str](
|
||||
encrypted_bytes=b"test_client_id",
|
||||
decrypt_fn=lambda b: b.decode(),
|
||||
)
|
||||
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
|
||||
|
||||
# Plain str input
|
||||
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
|
||||
|
||||
@@ -76,9 +76,12 @@ class ChatSessionManager:
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
project_id: int | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
|
||||
@@ -4,10 +4,12 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -32,7 +34,6 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -65,7 +66,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -75,9 +76,19 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -104,7 +115,7 @@ class LLMProviderManager:
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
return [LLMProviderView(**p) for p in response.json()["providers"]]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -113,7 +124,11 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -126,11 +141,30 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and (
|
||||
default_model is None or default_model.model_name in model_names
|
||||
)
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
default_text = response.json().get("default_text")
|
||||
if default_text is None:
|
||||
return None
|
||||
return DefaultModel(**default_text)
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ScimTokenManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": name},
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
raw_token=data["raw_token"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active(
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -42,6 +42,18 @@ class DATestPAT(BaseModel):
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestScimToken(BaseModel):
|
||||
"""SCIM bearer token model for testing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
raw_token: str | None = None # Only present on initial creation
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
@@ -116,7 +128,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
default_model_name: str | None = None
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
|
||||
@@ -42,12 +42,10 @@ def _create_provider_with_api(
|
||||
llm_provider_data = {
|
||||
"name": name,
|
||||
"provider": provider_type,
|
||||
"default_model_name": default_model,
|
||||
"api_key": "test-api-key-for-auto-mode-testing",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"custom_config": None,
|
||||
"fast_default_model_name": default_model,
|
||||
"is_public": True,
|
||||
"is_auto_mode": is_auto_mode,
|
||||
"groups": [],
|
||||
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json():
|
||||
for provider in response.json()["providers"]:
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
|
||||
"is_visible"
|
||||
], "Outdated model should not be visible after sync"
|
||||
|
||||
# Verify default model was set from GitHub config
|
||||
expected_default = (
|
||||
default_model["name"] if isinstance(default_model, dict) else default_model
|
||||
)
|
||||
assert synced_provider["default_model_name"] == expected_default, (
|
||||
f"Default model should be {expected_default}, "
|
||||
f"got {synced_provider['default_model_name']}"
|
||||
)
|
||||
|
||||
|
||||
def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
f"Manual mode provider models should not change. "
|
||||
f"Initial: {initial_models}, Current: {current_models}"
|
||||
)
|
||||
|
||||
assert (
|
||||
updated_provider["default_model_name"] == custom_model
|
||||
), f"Manual mode default model should remain {custom_model}"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,20 +6,21 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
@@ -41,24 +42,30 @@ def _create_llm_provider(
|
||||
is_public: bool,
|
||||
is_default: bool,
|
||||
) -> LLMProviderModel:
|
||||
provider = LLMProviderModel(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
if is_default:
|
||||
update_default_provider(_provider.id, default_model_name, db_session)
|
||||
|
||||
provider = db_session.get(LLMProviderModel, _provider.id)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {name} not found")
|
||||
return provider
|
||||
|
||||
|
||||
@@ -270,24 +277,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
provider_name=restricted_provider.name,
|
||||
)
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(default_model_config)
|
||||
db_session.flush()
|
||||
db_session.add(
|
||||
LLMModelFlow(
|
||||
model_configuration_id=default_model_config.id,
|
||||
llm_model_flow_type=LLMModelFlowType.CHAT,
|
||||
is_default=True,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
access_group = UserGroup(name="persona-group")
|
||||
db_session.add(access_group)
|
||||
db_session.flush()
|
||||
@@ -321,13 +310,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
persona=persona,
|
||||
user=admin_model,
|
||||
)
|
||||
assert allowed_llm.config.model_name == restricted_provider.default_model_name
|
||||
assert (
|
||||
allowed_llm.config.model_name
|
||||
== restricted_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
fallback_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=basic_model,
|
||||
)
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
assert (
|
||||
fallback_llm.config.model_name
|
||||
== default_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
@@ -346,6 +341,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -365,7 +361,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -380,7 +376,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
@@ -396,6 +392,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
|
||||
name="default-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider since basic_user can access the persona
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
|
||||
|
||||
# Should succeed (persona_id=0 refers to default persona, which is public)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should NOT include the restricted provider since basic_user is not in group2
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
|
||||
|
||||
# Should succeed - admins can access any persona
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should return the public provider
|
||||
assert len(providers) > 0
|
||||
|
||||
@@ -23,6 +23,8 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
|
||||
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
@@ -34,6 +36,8 @@ class NightlyProviderConfig(BaseModel):
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
deployment_name: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
@@ -45,17 +49,29 @@ def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
def _parse_models_env(env_var: str) -> list[str]:
|
||||
raw_value = os.environ.get(env_var, "").strip()
|
||||
if not raw_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(raw_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_json = None
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
return [str(model).strip() for model in parsed_json if str(model).strip()]
|
||||
|
||||
return [part.strip() for part in raw_value.split(",") if part.strip()]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
model_names = _parse_models_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
api_version = os.environ.get(_ENV_API_VERSION) or None
|
||||
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
@@ -74,6 +90,8 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
@@ -95,10 +113,15 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
if config.provider != "ollama_chat" and not (
|
||||
config.api_key or config.custom_config
|
||||
):
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
message=(
|
||||
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
|
||||
f"provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
@@ -109,6 +132,22 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "azure":
|
||||
if not config.api_base:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
if not config.api_version:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -147,6 +186,8 @@ def _create_provider_payload(
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
@@ -154,6 +195,8 @@ def _create_provider_payload(
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"deployment_name": deployment_name,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
@@ -255,6 +298,8 @@ def _create_and_test_provider_for_model(
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
api_version=config.api_version,
|
||||
deployment_name=config.deployment_name,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
@@ -313,10 +358,21 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
failures: list[str] = []
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
try:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
failures.append(
|
||||
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
|
||||
)
|
||||
|
||||
if failures:
|
||||
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
|
||||
|
||||
@@ -72,6 +72,9 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"read_file" in tool_names
|
||||
), "Default assistant should have FileReaderTool attached"
|
||||
assert (
|
||||
"python" in tool_names
|
||||
), "Default assistant should have PythonTool attached"
|
||||
|
||||
# Also verify by display names for clarity
|
||||
assert (
|
||||
@@ -86,8 +89,11 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"File Reader" in tool_display_names
|
||||
), "Default assistant should have File Reader tool"
|
||||
|
||||
# Should have exactly 5 tools
|
||||
assert (
|
||||
len(tool_associations) == 5
|
||||
), f"Default assistant should have exactly 5 tools attached, got {len(tool_associations)}"
|
||||
"Code Interpreter" in tool_display_names
|
||||
), "Default assistant should have Code Interpreter tool"
|
||||
|
||||
# Should have exactly 6 tools
|
||||
assert (
|
||||
len(tool_associations) == 6
|
||||
), f"Default assistant should have exactly 6 tools attached, got {len(tool_associations)}"
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Integration tests for the unified persona file context flow.
|
||||
|
||||
End-to-end tests that verify:
|
||||
1. Files can be uploaded and attached to a persona via API.
|
||||
2. The persona correctly reports its attached files.
|
||||
3. A chat session with a file-bearing persona processes without error.
|
||||
4. Precedence: custom persona files take priority over project files when
|
||||
the chat session is inside a project.
|
||||
|
||||
These tests run against a real Onyx deployment (all services running).
|
||||
File processing is asynchronous, so we poll the file status endpoint
|
||||
until files reach COMPLETED before chatting.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FILE_PROCESSING_POLL_INTERVAL = 2
|
||||
|
||||
|
||||
def _poll_file_statuses(
|
||||
user_file_ids: list[str],
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus = UserFileStatus.COMPLETED,
|
||||
timeout: int = MAX_DELAY,
|
||||
) -> None:
|
||||
"""Block until all files reach the target status or timeout expires."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/statuses",
|
||||
json={"file_ids": user_file_ids},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
statuses = response.json()
|
||||
if all(f["status"] == target_status.value for f in statuses):
|
||||
return
|
||||
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
|
||||
raise TimeoutError(
|
||||
f"Files {user_file_ids} did not reach {target_status.value} "
|
||||
f"within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_with_files_chat_no_error(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Upload files, attach them to a persona, wait for processing,
|
||||
then send a chat message. Verify no error is returned."""
|
||||
|
||||
# Upload files (creates UserFile records)
|
||||
text_file = create_test_text_file(
|
||||
"The secret project codename is NIGHTINGALE. "
|
||||
"It was started in 2024 by the Advanced Research division."
|
||||
)
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("nightingale_brief.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
assert len(file_descriptors) == 1
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"]
|
||||
assert user_file_id is not None
|
||||
|
||||
# Wait for file processing
|
||||
_poll_file_statuses([user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with the file attached
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Nightingale Agent",
|
||||
description="Agent with secret file",
|
||||
system_prompt="You are a helpful assistant with access to uploaded files.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
# Verify persona has the file
|
||||
persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
|
||||
assert len(persona_snapshots) == 1
|
||||
assert user_file_id in persona_snapshots[0].user_file_ids
|
||||
|
||||
# Chat with the persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test persona file context",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret project codename?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
assert len(response.full_message) > 0, "Response should not be empty"
|
||||
|
||||
|
||||
def test_persona_without_files_still_works(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A persona with no attached files should still chat normally."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Blank Agent",
|
||||
description="No files attached",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test blank persona",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="Hello, how are you?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
def test_persona_files_override_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a custom persona (with its own files) is used inside a project,
|
||||
the persona's files take precedence — the project's files are invisible.
|
||||
|
||||
We verify this by putting different content in project vs persona files
|
||||
and checking which content the model responds with."""
|
||||
|
||||
# Upload persona file
|
||||
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
|
||||
persona_fds, err1 = FileManager.upload_files(
|
||||
files=[("persona_secret.txt", persona_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not err1
|
||||
persona_user_file_id = persona_fds[0]["user_file_id"]
|
||||
assert persona_user_file_id is not None
|
||||
# Create a project and upload project files
|
||||
project = ProjectManager.create(
|
||||
name="Precedence Test Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_secret.txt", b"The project's secret word is FLAMINGO."),
|
||||
]
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for both persona and project file processing
|
||||
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with persona file
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Override Agent",
|
||||
description="Persona with its own files",
|
||||
system_prompt="You are a helpful assistant. Answer using the files.",
|
||||
user_file_ids=[persona_user_file_id],
|
||||
)
|
||||
|
||||
# Create chat session inside the project but using the custom persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret word?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
# The persona's file should be what the model sees, not the project's
|
||||
message_lower = response.full_message.lower()
|
||||
assert "albatross" in message_lower, (
|
||||
"Response should reference the persona file's secret word (ALBATROSS), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_default_persona_in_project_uses_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the default persona (id=0) is used inside a project,
|
||||
the project's files should be used for context."""
|
||||
project = ProjectManager.create(
|
||||
name="Default Persona Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_info.txt", b"The project mascot is a PANGOLIN."),
|
||||
]
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(upload_result.user_files) == 1
|
||||
|
||||
# Wait for project file processing
|
||||
project_file_id = str(upload_result.user_files[0].id)
|
||||
_poll_file_statuses([project_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create chat session inside project using default persona (id=0)
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project mascot?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert "pangolin" in response.full_message.lower(), (
|
||||
"Response should reference the project file content (PANGOLIN), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_custom_persona_no_files_in_project_ignores_project(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A custom persona with NO files, used inside a project with files,
|
||||
should NOT see the project's files. The project is purely organizational.
|
||||
|
||||
We verify by asking about content only in the project file and checking
|
||||
the model does NOT reference it."""
|
||||
|
||||
project = ProjectManager.create(
|
||||
name="Ignored Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for project file processing
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Custom persona with no files
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="No Files Agent",
|
||||
description="No files, project is irrelevant",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant. If you do not have information "
|
||||
"to answer a question, say 'I do not have that information.'"
|
||||
),
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project secret?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
assert "capybara" not in response.full_message.lower(), (
|
||||
"Response should NOT reference the project file content (CAPYBARA) "
|
||||
"because the custom persona has no files and should not inherit "
|
||||
f"project files, but got: {response.full_message}"
|
||||
)
|
||||
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Integration tests for SCIM token management.
|
||||
|
||||
Covers the admin token API and SCIM bearer-token authentication:
|
||||
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
|
||||
2. Token rotation: creating a new token revokes previous tokens
|
||||
3. Revoked tokens are rejected by SCIM endpoints
|
||||
4. Non-admin users cannot manage SCIM tokens
|
||||
5. SCIM requests without a token are rejected
|
||||
6. Service discovery endpoints work without authentication
|
||||
7. last_used_at is updated after a SCIM request
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
"""Create token → retrieve metadata → use for SCIM request."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Test SCIM Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert token.raw_token is not None
|
||||
assert token.raw_token.startswith("onyx_scim_")
|
||||
assert token.is_active is True
|
||||
assert "****" in token.token_display
|
||||
|
||||
# GET returns the same metadata but raw_token is None because the
|
||||
# server only reveals the raw token once at creation time (it stores
|
||||
# only the SHA-256 hash).
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
assert body["totalResults"] >= 0
|
||||
|
||||
|
||||
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
"""Creating a new token automatically revokes the previous one."""
|
||||
first = ScimTokenManager.create(
|
||||
name="First Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
second = ScimTokenManager.create(
|
||||
name="Second Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert second.raw_token is not None
|
||||
|
||||
# Active token should now be the second one
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to create a SCIM token."""
|
||||
basic_user = UserManager.create(name="scim_basic_user")
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": "Should Fail"},
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_non_admin_cannot_get_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to retrieve SCIM token metadata."""
|
||||
basic_user = UserManager.create(name="scim_basic_user2")
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
|
||||
"""GET active token returns 404 when no token exists."""
|
||||
# new_admin_user depends on the reset fixture, ensuring a clean DB
|
||||
# with no active SCIM tokens.
|
||||
active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
|
||||
assert active is None
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=new_admin_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_service_discovery_no_auth_required(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
def test_last_used_at_updated_after_scim_request(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""last_used_at timestamp is updated after using the token."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Last Used Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert token.raw_token is not None
|
||||
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active is not None
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active_after is not None
|
||||
assert active_after.last_used_at is not None
|
||||
@@ -9,6 +9,19 @@ from redis.exceptions import RedisError
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
# Fields we assert on across all tests
|
||||
_ASSERT_FIELDS = {
|
||||
"application_status",
|
||||
"ee_features_enabled",
|
||||
"seat_count",
|
||||
"used_seats",
|
||||
}
|
||||
|
||||
|
||||
def _pick(settings: Settings) -> dict:
|
||||
"""Extract only the fields under test from a Settings object."""
|
||||
return settings.model_dump(include=_ASSERT_FIELDS)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_settings() -> Settings:
|
||||
@@ -27,17 +40,17 @@ class TestApplyLicenseStatusToSettings:
|
||||
def test_enforcement_disabled_enables_ee_features(
|
||||
self, base_settings: Settings
|
||||
) -> None:
|
||||
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
|
||||
|
||||
If we're running the EE apply function, EE code was loaded via
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
|
||||
"""
|
||||
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
assert base_settings.ee_features_enabled is False
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is True
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
|
||||
@@ -46,13 +59,60 @@ class TestApplyLicenseStatusToSettings:
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.ee_features_enabled is True
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"license_status,expected_app_status,expected_ee_enabled",
|
||||
"license_status,used_seats,seats,expected",
|
||||
[
|
||||
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
|
||||
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
|
||||
(
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.ACTIVE,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.ACTIVE,
|
||||
10,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -63,25 +123,80 @@ class TestApplyLicenseStatusToSettings:
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
license_status: ApplicationStatus | None,
|
||||
expected_app_status: ApplicationStatus,
|
||||
expected_ee_enabled: bool,
|
||||
license_status: ApplicationStatus,
|
||||
used_seats: int,
|
||||
seats: int,
|
||||
expected: dict,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
if license_status is None:
|
||||
mock_get_metadata.return_value = None
|
||||
else:
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = license_status
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = license_status
|
||||
mock_metadata.used_seats = used_seats
|
||||
mock_metadata.seats = seats
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == expected_app_status
|
||||
assert result.ee_features_enabled is expected_ee_enabled
|
||||
assert _pick(result) == expected
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_seat_limit_exceeded_sets_status_and_counts(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = ApplicationStatus.ACTIVE
|
||||
mock_metadata.used_seats = 15
|
||||
mock_metadata.seats = 10
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": 10,
|
||||
"used_seats": 15,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_expired_license_takes_precedence_over_seat_limit(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = ApplicationStatus.GATED_ACCESS
|
||||
mock_metadata.used_seats = 15
|
||||
mock_metadata.seats = 10
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -105,8 +220,12 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.GATED_ACCESS
|
||||
assert result.ee_features_enabled is False
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -130,8 +249,12 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is False
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@@ -150,8 +273,12 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.side_effect = RedisError("Connection failed")
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is False
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
|
||||
|
||||
class TestSettingsDefaultEEDisabled:
|
||||
|
||||
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tests for the unified context file extraction logic (Phase 5).
|
||||
|
||||
Covers:
|
||||
- resolve_context_user_files: precedence rule (custom persona supersedes project)
|
||||
- extract_context_files: all-or-nothing context window fit check
|
||||
- Search filter / search_usage determination in the caller
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.process_message import determine_search_params
|
||||
from onyx.chat.process_message import extract_context_files
|
||||
from onyx.chat.process_message import resolve_context_user_files
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
token_count: int = 100,
|
||||
name: str = "file.txt",
|
||||
file_id: str | None = None,
|
||||
) -> UserFile:
|
||||
file_uuid = UUID(file_id) if file_id else uuid4()
|
||||
return UserFile(
|
||||
id=file_uuid,
|
||||
file_id=str(file_uuid),
|
||||
name=name,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(
|
||||
persona_id: int,
|
||||
user_files: list | None = None,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.user_files = user_files or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_in_memory_file(
|
||||
file_id: str,
|
||||
content: str = "hello world",
|
||||
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
|
||||
filename: str = "file.txt",
|
||||
) -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=content.encode("utf-8"),
|
||||
file_type=file_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# resolve_context_user_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveContextUserFiles:
|
||||
"""Precedence rule: custom persona fully supersedes project."""
|
||||
|
||||
def test_custom_persona_with_files_returns_persona_files(self) -> None:
|
||||
persona_files = [_make_user_file(), _make_user_file()]
|
||||
persona = _make_persona(persona_id=42, user_files=persona_files)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == persona_files
|
||||
|
||||
def test_custom_persona_without_files_returns_empty(self) -> None:
|
||||
"""Custom persona with no files should NOT fall through to project."""
|
||||
persona = _make_persona(persona_id=42, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_custom_persona_none_files_returns_empty(self) -> None:
|
||||
"""Custom persona with user_files=None should NOT fall through."""
|
||||
persona = _make_persona(persona_id=42, user_files=None)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_default_persona_in_project_returns_project_files(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
project_files = [_make_user_file(), _make_user_file()]
|
||||
mock_get_files.return_value = project_files
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
user_id = uuid4()
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
assert result == project_files
|
||||
mock_get_files.assert_called_once_with(
|
||||
project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
def test_default_persona_no_project_returns_empty(self) -> None:
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_custom_persona_without_files_ignores_project(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
"""Even with a project_id, custom persona means project is invisible."""
|
||||
persona = _make_persona(persona_id=7, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_get_files.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# extract_context_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractContextFiles:
|
||||
"""All-or-nothing context window fit check."""
|
||||
|
||||
def test_empty_user_files_returns_empty(self) -> None:
|
||||
db_session = MagicMock()
|
||||
result = extract_context_files(
|
||||
user_files=[],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.uncapped_token_count is None
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=100, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=file_id, content="file content")
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["file content"]
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.total_token_count == 100
|
||||
assert len(result.file_metadata) == 1
|
||||
assert result.file_metadata[0].file_id == file_id
|
||||
|
||||
def test_files_overflow_context_not_loaded(self) -> None:
|
||||
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
|
||||
uf = _make_user_file(token_count=7000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.uncapped_token_count == 7000
|
||||
assert result.total_token_count == 0
|
||||
|
||||
def test_overflow_boundary_exact(self) -> None:
|
||||
"""Token count exactly at the 60% boundary should trigger overflow."""
|
||||
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
|
||||
uf = _make_user_file(token_count=6000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
|
||||
"""Token count just under the 60% boundary should load files."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=5999, file_id=file_id)
|
||||
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.file_texts == ["data"]
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
|
||||
"""Multiple small files that individually fit but collectively overflow."""
|
||||
files = [_make_user_file(token_count=2500) for _ in range(3)]
|
||||
# 3 * 2500 = 7500 > 6000 threshold
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=files,
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
|
||||
"""Reserved tokens shrink the available window."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=3000, file_id=file_id)
|
||||
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=5000,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=50, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert len(result.image_files) == 1
|
||||
assert result.image_files[0].file_id == file_id
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that determine_search_params correctly resolves
|
||||
search_project_id, search_persona_id, and search_usage based on
|
||||
the extraction result and the precedence rule.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_context(
|
||||
use_as_search_filter: bool = False,
|
||||
file_texts: list[str] | None = None,
|
||||
uncapped_token_count: int | None = None,
|
||||
) -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts or [],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=uncapped_token_count,
|
||||
)
|
||||
|
||||
def test_custom_persona_files_fit_no_filter(self) -> None:
|
||||
"""Custom persona, files fit → no search filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
"""Custom persona, files overflow → persona_id filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(use_as_search_filter=True),
|
||||
)
|
||||
assert result.search_persona_id == 42
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
"""Custom persona (no files) in project → nothing leaks from project."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
"""Default persona, project files fit → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
"""Default persona, project files overflow → ENABLED + project_id filter."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
use_as_search_filter=True,
|
||||
uncapped_token_count=7000,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id == 99
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
"""Default persona, no project → AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=None,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
"""Default persona in project with no files → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -74,20 +74,20 @@ def create_tool_response(
|
||||
)
|
||||
|
||||
|
||||
def create_project_files(
|
||||
def create_context_files(
|
||||
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Helper to create ExtractedProjectFiles for testing."""
|
||||
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
project_file_metadata = [
|
||||
ProjectFileMetadata(
|
||||
) -> ExtractedContextFiles:
|
||||
"""Helper to create ExtractedContextFiles for testing."""
|
||||
file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
file_metadata = [
|
||||
ContextFileMetadata(
|
||||
file_id=f"file_{i}",
|
||||
filename=f"file_{i}.txt",
|
||||
file_content=f"Project file {i} content",
|
||||
)
|
||||
for i in range(num_files)
|
||||
]
|
||||
project_image_files = [
|
||||
image_files = [
|
||||
ChatLoadedFile(
|
||||
file_id=f"image_{i}",
|
||||
content=b"",
|
||||
@@ -98,13 +98,13 @@ def create_project_files(
|
||||
)
|
||||
for i in range(num_images)
|
||||
]
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=False,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=num_files * tokens_per_file,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=num_files * tokens_per_file,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=num_files * tokens_per_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("How are you?", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
|
||||
assert result[3] == custom_agent # Before last user message
|
||||
assert result[4] == user_msg2
|
||||
|
||||
def test_with_project_files(self) -> None:
|
||||
def test_with_context_files(self) -> None:
|
||||
"""Test that project files are inserted before the last user message."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg1 = create_message("First message", MessageType.USER, 5)
|
||||
user_msg2 = create_message("Second message", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, project_files_message, user2
|
||||
# Should have: system, user1, context_files_message, user2
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
|
||||
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
|
||||
assert len(result) == 6
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Second", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=0, num_images=2)
|
||||
context_files = create_context_files(num_files=0, num_images=2)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
|
||||
)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=0, num_images=1)
|
||||
context_files = create_context_files(num_files=0, num_images=1)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg2,
|
||||
user_msg3,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
|
||||
result = construct_message_history(
|
||||
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
|
||||
tool_response = create_tool_response("tc_1", "tool_response", 20)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last user message and messages after + system
|
||||
# (10 + 20 + 20 + 20 = 70 tokens)
|
||||
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg1,
|
||||
user_msg2,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
|
||||
# keeps [tool_response, assistant_msg1] from history_before_last_user,
|
||||
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=30,
|
||||
)
|
||||
|
||||
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Latest question", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
|
||||
# keeps both assistant_with_tool and tool_response in history_before_last_user.
|
||||
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=45,
|
||||
)
|
||||
|
||||
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Reminder", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history: list[ChatMessageSimple] = []
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, custom_agent, project_files, reminder
|
||||
# Should have: system, custom_agent, context_files, reminder
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == custom_agent
|
||||
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
|
||||
|
||||
simple_chat_history = [assistant_msg, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
with pytest.raises(ValueError, match="No user message found"):
|
||||
construct_message_history(
|
||||
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 50)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=100)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=100)
|
||||
|
||||
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
|
||||
# But only 200 available
|
||||
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=200,
|
||||
)
|
||||
|
||||
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget: 50 tokens
|
||||
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
|
||||
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=50,
|
||||
)
|
||||
|
||||
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=20)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=20)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Expected order:
|
||||
# system, user1, assistant1, user2, assistant2,
|
||||
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
|
||||
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
|
||||
assert len(result) == 11
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
|
||||
assert result[9] == tool_response # After last user
|
||||
assert result[10] == reminder # At the very end
|
||||
|
||||
def test_project_files_json_format(self) -> None:
|
||||
def test_context_files_json_format(self) -> None:
|
||||
"""Test that project files are formatted correctly as JSON."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Hello", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=create_project_files(),
|
||||
context_files=create_context_files(),
|
||||
available_tokens=available_tokens,
|
||||
token_counter=_simple_token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -44,7 +44,6 @@ def _build_provider_view(
|
||||
id=1,
|
||||
name="test-provider",
|
||||
provider=provider,
|
||||
default_model_name="test-model",
|
||||
model_configurations=[
|
||||
ModelConfigurationView(
|
||||
name="test-model",
|
||||
@@ -62,7 +61,6 @@ def _build_provider_view(
|
||||
groups=[],
|
||||
personas=[],
|
||||
deployment_name=None,
|
||||
default_vision_model=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -106,6 +106,9 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_tenant_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool._connections_held"
|
||||
) as mock_gauge,
|
||||
@@ -114,12 +117,14 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
mock_labels = MagicMock()
|
||||
mock_gauge.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "/api/chat/send-message"
|
||||
mock_tenant_ctx.get.return_value = "tenant_xyz"
|
||||
listeners["checkout"](None, conn_record, None)
|
||||
|
||||
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
|
||||
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
|
||||
assert "_metrics_checkout_time" in conn_record.info
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/chat/send-message", engine="sync"
|
||||
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
@@ -144,6 +149,7 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
conn_record = _make_conn_record()
|
||||
conn_record.info["_metrics_endpoint"] = "/api/search"
|
||||
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
|
||||
|
||||
with (
|
||||
@@ -162,7 +168,9 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/search", engine="sync", tenant_id="tenant_abc"
|
||||
)
|
||||
mock_labels.dec.assert_called_once()
|
||||
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_hist_labels.observe.assert_called_once()
|
||||
@@ -172,11 +180,12 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
# conn_record.info should be cleaned up
|
||||
assert "_metrics_endpoint" not in conn_record.info
|
||||
assert "_metrics_tenant_id" not in conn_record.info
|
||||
assert "_metrics_checkout_time" not in conn_record.info
|
||||
|
||||
|
||||
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
"""Verify checkin gracefully handles missing endpoint info."""
|
||||
"""Verify checkin gracefully handles missing endpoint and tenant info."""
|
||||
engine = MagicMock()
|
||||
engine.pool = MagicMock()
|
||||
listeners: dict[str, Any] = {}
|
||||
@@ -207,7 +216,9 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="unknown", engine="sync", tenant_id="unknown"
|
||||
)
|
||||
|
||||
|
||||
# --- setup_postgres_connection_pool_metrics tests ---
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -81,7 +82,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=["/health", "/metrics", "/openapi.json"],
|
||||
)
|
||||
mock_instance.add.assert_called_once()
|
||||
assert mock_instance.add.call_count == 3
|
||||
mock_instance.instrument.assert_called_once_with(
|
||||
app,
|
||||
latency_lowr_buckets=(
|
||||
@@ -100,6 +101,56 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
mock_instance.expose.assert_called_once_with(app)
|
||||
|
||||
|
||||
def test_per_tenant_callback_increments_with_tenant_id() -> None:
|
||||
"""Verify per-tenant callback reads tenant from contextvar and increments."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "tenant_abc"
|
||||
|
||||
info = _make_info(
|
||||
duration=0.1, method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="tenant_abc",
|
||||
method="POST",
|
||||
handler="/api/chat",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_per_tenant_callback_falls_back_to_unknown() -> None:
|
||||
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = None
|
||||
|
||||
info = _make_info(duration=0.1)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="unknown",
|
||||
method="GET",
|
||||
handler="/api/test",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_inprogress_gauge_increments_during_request() -> None:
|
||||
"""Verify the in-progress gauge goes up while a request is in flight."""
|
||||
registry = CollectorRegistry()
|
||||
|
||||
@@ -163,3 +163,16 @@ Add clear comments:
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side
|
||||
effects. Essentially every piece of meaningful logic should exist within some
|
||||
function that has to be explicitly invoked. Acceptable exceptions to this may
|
||||
include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to
|
||||
exist in a file dedicated for manual execution (contains `if __name__ ==
|
||||
"__main__":`) which should not be imported by anything else.
|
||||
- Related to the above, do not conflate Python scripts you intend to run from
|
||||
the command line (contains `if __name__ == "__main__":`) with modules you
|
||||
intend to import from elsewhere. If for some unlikely reason they have to be
|
||||
the same file, any logic specific to executing the file (including imports)
|
||||
should be contained in the `if __name__ == "__main__":` block.
|
||||
- Generally these executable files exist in `backend/scripts/`.
|
||||
|
||||
@@ -468,7 +468,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -293,7 +293,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -298,7 +298,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -335,7 +335,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -232,7 +232,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -520,7 +520,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@@ -534,9 +534,10 @@ services:
|
||||
required: false
|
||||
|
||||
# Below is needed for the `docker-out-of-docker` execution mode
|
||||
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
|
||||
user: root
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
|
||||
|
||||
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
|
||||
# privileged: true
|
||||
|
||||
@@ -10,7 +10,7 @@ requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aioboto3==15.1.0",
|
||||
"cohere==5.6.1",
|
||||
"fastapi==0.128.0",
|
||||
"fastapi==0.133.1",
|
||||
"google-cloud-aiplatform==1.121.0",
|
||||
"google-genai==1.52.0",
|
||||
"litellm==1.81.6",
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.6.2",
|
||||
"pypdf==6.7.3",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
|
||||
@@ -51,6 +51,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewRunCICommand())
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
159
tools/ods/cmd/whois.go
Normal file
159
tools/ods/cmd/whois.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/kube"
|
||||
)
|
||||
|
||||
var safeIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`)
|
||||
|
||||
// NewWhoisCommand creates the whois command for looking up users/tenants.
|
||||
func NewWhoisCommand() *cobra.Command {
|
||||
var ctx string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "whois <email-fragment or tenant-id>",
|
||||
Short: "Look up users and admins by email or tenant ID",
|
||||
Long: `Look up tenant and user information from the data plane PostgreSQL database.
|
||||
|
||||
Requires: AWS SSO login, kubectl access to the EKS cluster.
|
||||
|
||||
Two modes (auto-detected):
|
||||
|
||||
Email fragment:
|
||||
ods whois chris
|
||||
→ Searches user_tenant_mapping for emails matching '%chris%'
|
||||
|
||||
Tenant ID:
|
||||
ods whois tenant_abcd1234-...
|
||||
→ Lists all admin emails in that tenant
|
||||
|
||||
Cluster connection is configured via KUBE_CTX_* environment variables.
|
||||
Each variable is a space-separated tuple: "cluster region namespace"
|
||||
|
||||
export KUBE_CTX_DATA_PLANE="<cluster> <region> <namespace>"
|
||||
export KUBE_CTX_CONTROL_PLANE="<cluster> <region> <namespace>"
|
||||
etc...
|
||||
|
||||
Use -c to select which context (default: data_plane).`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runWhois(args[0], ctx)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&ctx, "context", "c", "data_plane", "cluster context name (maps to KUBE_CTX_<NAME> env var)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func clusterFromEnv(name string) *kube.Cluster {
|
||||
envKey := "KUBE_CTX_" + strings.ToUpper(name)
|
||||
val := os.Getenv(envKey)
|
||||
if val == "" {
|
||||
log.Fatalf("Environment variable %s is not set.\n\nSet it as a space-separated tuple:\n export %s=\"<cluster> <region> <namespace>\"", envKey, envKey)
|
||||
}
|
||||
|
||||
parts := strings.Fields(val)
|
||||
if len(parts) != 3 {
|
||||
log.Fatalf("%s must be a space-separated tuple of 3 values (cluster region namespace), got: %q", envKey, val)
|
||||
}
|
||||
|
||||
return &kube.Cluster{Name: parts[0], Region: parts[1], Namespace: parts[2]}
|
||||
}
|
||||
|
||||
// queryPod runs a SQL query via pginto on the given pod and returns cleaned output lines.
|
||||
func queryPod(c *kube.Cluster, pod, sql string) []string {
|
||||
raw, err := c.ExecOnPod(pod, "pginto", "-A", "-t", "-F", "\t", "-c", sql)
|
||||
if err != nil {
|
||||
log.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for _, line := range strings.Split(strings.TrimSpace(raw), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && !strings.HasPrefix(line, "Connecting to ") {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func runWhois(query string, ctx string) {
|
||||
c := clusterFromEnv(ctx)
|
||||
|
||||
if err := c.EnsureContext(); err != nil {
|
||||
log.Fatalf("Failed to ensure cluster context: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Finding api-server pod...")
|
||||
pod, err := c.FindPod("api-server")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find api-server pod: %v", err)
|
||||
}
|
||||
log.Debugf("Using pod: %s", pod)
|
||||
|
||||
if strings.HasPrefix(query, "tenant_") {
|
||||
findAdminsByTenant(c, pod, query)
|
||||
} else {
|
||||
findByEmail(c, pod, query)
|
||||
}
|
||||
}
|
||||
|
||||
func findByEmail(c *kube.Cluster, pod, fragment string) {
|
||||
fragment = strings.NewReplacer("'", "", `"`, "", `;`, "", `\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(fragment)
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email, tenant_id, active FROM public.user_tenant_mapping WHERE email LIKE '%%%s%%' ORDER BY email;`,
|
||||
fragment,
|
||||
)
|
||||
|
||||
log.Infof("Searching for emails matching '%%%s%%'...", fragment)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No results found.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "EMAIL\tTENANT ID\tACTIVE")
|
||||
_, _ = fmt.Fprintln(w, "-----\t---------\t------")
|
||||
for _, line := range lines {
|
||||
_, _ = fmt.Fprintln(w, line)
|
||||
}
|
||||
_ = w.Flush()
|
||||
}
|
||||
|
||||
func findAdminsByTenant(c *kube.Cluster, pod, tenantID string) {
|
||||
if !safeIdentifier.MatchString(tenantID) {
|
||||
log.Fatalf("Invalid tenant ID: %q (must be alphanumeric, hyphens, underscores only)", tenantID)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email FROM "%s"."user" WHERE role = 'ADMIN' AND is_active = true AND email NOT LIKE 'api_key__%%' ORDER BY email;`,
|
||||
tenantID,
|
||||
)
|
||||
|
||||
log.Infof("Fetching admin emails for %s...", tenantID)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No admin users found for this tenant.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("EMAIL")
|
||||
fmt.Println("-----")
|
||||
for _, line := range lines {
|
||||
fmt.Println(line)
|
||||
}
|
||||
}
|
||||
90
tools/ods/internal/kube/kube.go
Normal file
90
tools/ods/internal/kube/kube.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package kube
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cluster holds the connection info for a Kubernetes cluster.
|
||||
type Cluster struct {
|
||||
Name string
|
||||
Region string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
// EnsureContext makes sure the cluster exists in kubeconfig, calling
|
||||
// aws eks update-kubeconfig only if the context is missing.
|
||||
func (c *Cluster) EnsureContext() error {
|
||||
// Check if context already exists in kubeconfig
|
||||
cmd := exec.Command("kubectl", "config", "get-contexts", c.Name, "--no-headers")
|
||||
if err := cmd.Run(); err == nil {
|
||||
log.Debugf("Context %s already exists, skipping aws eks update-kubeconfig", c.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Context %s not found, fetching kubeconfig from AWS...", c.Name)
|
||||
cmd = exec.Command("aws", "eks", "update-kubeconfig", "--region", c.Region, "--name", c.Name, "--alias", c.Name)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("aws eks update-kubeconfig failed: %w\n%s", err, string(out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// kubectlArgs returns common kubectl flags to target this cluster without mutating global context.
|
||||
func (c *Cluster) kubectlArgs() []string {
|
||||
return []string{"--context", c.Name, "--namespace", c.Namespace}
|
||||
}
|
||||
|
||||
// FindPod returns the name of the first Running/Ready pod matching the given substring.
|
||||
func (c *Cluster) FindPod(substring string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "get", "po",
|
||||
"--field-selector", "status.phase=Running",
|
||||
"--no-headers",
|
||||
"-o", "custom-columns=NAME:.metadata.name,READY:.status.conditions[?(@.type=='Ready')].status",
|
||||
)
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("kubectl get po failed: %w\n%s", err, string(exitErr.Stderr))
|
||||
}
|
||||
return "", fmt.Errorf("kubectl get po failed: %w", err)
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
name, ready := fields[0], fields[1]
|
||||
if strings.Contains(name, substring) && ready == "True" {
|
||||
log.Debugf("Found pod: %s", name)
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no ready pod found matching %q", substring)
|
||||
}
|
||||
|
||||
// ExecOnPod runs a command on a pod and returns its stdout.
|
||||
func (c *Cluster) ExecOnPod(pod string, command ...string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "exec", pod, "--")
|
||||
args = append(args, command...)
|
||||
log.Debugf("Running: kubectl %s", strings.Join(args, " "))
|
||||
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("kubectl exec failed: %w\n%s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
}
|
||||
23
uv.lock
generated
23
uv.lock
generated
@@ -1688,17 +1688,18 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.133.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "typing-inspection" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/6f/0eafed8349eea1fa462238b54a624c8b408cd1ba2795c8e64aa6c34f8ab7/fastapi-0.133.1.tar.gz", hash = "sha256:ed152a45912f102592976fde6cbce7dae1a8a1053da94202e51dd35d184fadd6", size = 378741, upload-time = "2026-02-25T18:18:17.398Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/c9/a175a7779f3599dfa4adfc97a6ce0e157237b3d7941538604aadaf97bfb6/fastapi-0.133.1-py3-none-any.whl", hash = "sha256:658f34ba334605b1617a65adf2ea6461901bdb9af3a3080d63ff791ecf7dc2e2", size = 109029, upload-time = "2026-02-25T18:18:18.578Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4612,7 +4613,7 @@ requires-dist = [
|
||||
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
|
||||
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
|
||||
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
|
||||
{ name = "fastapi", specifier = "==0.128.0" },
|
||||
{ name = "fastapi", specifier = "==0.133.1" },
|
||||
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
|
||||
@@ -4677,7 +4678,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.6.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.3" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -5924,11 +5925,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.6.2"
|
||||
version = "6.7.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/53/9b/63e767042fc852384dc71e5ff6f990ee4e1b165b1526cf3f9c23a4eebb47/pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2", size = 5303304, upload-time = "2026-02-24T17:23:11.42Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/90/3308a9b8b46c1424181fdf3f4580d2b423c5471425799e7fc62f92d183f4/pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a", size = 331263, upload-time = "2026-02-24T17:23:09.932Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8079,14 +8080,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.1.5"
|
||||
version = "3.1.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markupsafe" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/61/f1/ee81806690a87dab5f5653c1f146c92bc066d7f4cebc603ef88eb9e13957/werkzeug-3.1.6.tar.gz", hash = "sha256:210c6bede5a420a913956b4791a7f4d6843a43b6fcee4dfa08a65e93007d0d25", size = 864736, upload-time = "2026-02-19T15:17:18.884Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -42,7 +42,7 @@ import SvgStar from "@opal/icons/star";
|
||||
|
||||
## Usage inside Content
|
||||
|
||||
Tag can be rendered as an accessory inside `Content`'s LabelLayout via the `tag` prop:
|
||||
Tag can be rendered as an accessory inside `Content`'s ContentMd via the `tag` prop:
|
||||
|
||||
```tsx
|
||||
import { Content } from "@opal/layouts";
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import "@opal/components/buttons/Button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
} from "@opal/core";
|
||||
import type { SizeVariant } from "@opal/shared";
|
||||
import { Interactive, type InteractiveBaseProps } from "@opal/core";
|
||||
import type { SizeVariant, WidthVariant } from "@opal/shared";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
@@ -91,7 +87,7 @@ type ButtonProps = InteractiveBaseProps &
|
||||
tooltip?: string;
|
||||
|
||||
/** Width preset. `"auto"` shrink-wraps, `"full"` stretches to parent width. */
|
||||
width?: InteractiveContainerWidthVariant;
|
||||
width?: WidthVariant;
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
233
web/lib/opal/src/core/hoverable/components.tsx
Normal file
233
web/lib/opal/src/core/hoverable/components.tsx
Normal file
@@ -0,0 +1,233 @@
|
||||
import "@opal/core/hoverable/styles.css";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context-per-group registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Lazily-created map of group names to React contexts.
|
||||
*
|
||||
* Each group gets its own `React.Context<boolean | null>` so that a
|
||||
* `Hoverable.Item` only re-renders when its *own* group's hover state
|
||||
* changes — not when any unrelated group changes.
|
||||
*
|
||||
* The default value is `null` (no provider found), which lets
|
||||
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
|
||||
* not hovered" and throw when `group` was explicitly specified.
|
||||
*/
|
||||
const contextMap = new Map<string, React.Context<boolean | null>>();
|
||||
|
||||
function getOrCreateContext(group: string): React.Context<boolean | null> {
|
||||
let ctx = contextMap.get(group);
|
||||
if (!ctx) {
|
||||
ctx = createContext<boolean | null>(null);
|
||||
ctx.displayName = `HoverableContext(${group})`;
|
||||
contextMap.set(group, ctx);
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HoverableRootProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group: string;
|
||||
}
|
||||
|
||||
type HoverableItemVariant = "opacity-on-hover";
|
||||
|
||||
interface HoverableItemProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group?: string;
|
||||
variant?: HoverableItemVariant;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableRoot
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hover-tracking container for a named group.
|
||||
*
|
||||
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
|
||||
* provides the hover state via a per-group React context.
|
||||
*
|
||||
* Nesting works because each `Hoverable.Root` creates a **new** context
|
||||
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
|
||||
* reads from the inner provider, not the outer `group="a"` provider.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*/
|
||||
function HoverableRoot({
|
||||
group,
|
||||
children,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(true);
|
||||
consumerMouseEnter?.(e);
|
||||
},
|
||||
[consumerMouseEnter]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(false);
|
||||
consumerMouseLeave?.(e);
|
||||
},
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<GroupContext.Provider value={hovered}>
|
||||
<div {...props} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{children}
|
||||
</div>
|
||||
</GroupContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableItem
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* An element whose visibility is controlled by hover state.
|
||||
*
|
||||
* **Local mode** (`group` omitted): the item handles hover on its own
|
||||
* element via CSS `:hover`. This is the core abstraction.
|
||||
*
|
||||
* **Group mode** (`group` provided): visibility is driven by a matching
|
||||
* `Hoverable.Root` ancestor's hover state via React context. If no
|
||||
* matching Root is found, an error is thrown.
|
||||
*
|
||||
* Uses data-attributes for variant styling (see `styles.css`).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Local mode — hover on the item itself
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
*
|
||||
* // Group mode — hover on the Root reveals the item
|
||||
* <Hoverable.Root group="card">
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*
|
||||
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
|
||||
*/
|
||||
function HoverableItem({
|
||||
group,
|
||||
variant = "opacity-on-hover",
|
||||
children,
|
||||
...props
|
||||
}: HoverableItemProps) {
|
||||
const contextValue = useContext(
|
||||
group ? getOrCreateContext(group) : NOOP_CONTEXT
|
||||
);
|
||||
|
||||
if (group && contextValue === null) {
|
||||
throw new Error(
|
||||
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
|
||||
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
|
||||
);
|
||||
}
|
||||
|
||||
const isLocal = group === undefined;
|
||||
|
||||
return (
|
||||
<div
|
||||
{...props}
|
||||
className={cn("hoverable-item")}
|
||||
data-hoverable-variant={variant}
|
||||
data-hoverable-active={
|
||||
isLocal ? undefined : contextValue ? "true" : undefined
|
||||
}
|
||||
data-hoverable-local={isLocal ? "true" : undefined}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/** Stable context used when no group is specified (local mode). */
|
||||
const NOOP_CONTEXT = createContext<boolean | null>(null);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compound export
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hoverable compound component for hover-to-reveal patterns.
|
||||
*
|
||||
* Provides two sub-components:
|
||||
*
|
||||
* - `Hoverable.Root` — A container that tracks hover state for a named group
|
||||
* and provides it via React context.
|
||||
*
|
||||
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
|
||||
* applies local CSS `:hover` for the variant effect. When `group` is
|
||||
* specified, it reads hover state from the nearest matching
|
||||
* `Hoverable.Root` — and throws if no matching Root is found.
|
||||
*
|
||||
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
|
||||
* so each group's items only respond to their own root's hover.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* import { Hoverable } from "@opal/core";
|
||||
*
|
||||
* // Group mode — hovering the card reveals the trash icon
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <span>Card content</span>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
*
|
||||
* // Local mode — hovering the item itself reveals it
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* ```
|
||||
*/
|
||||
const Hoverable = {
|
||||
Root: HoverableRoot,
|
||||
Item: HoverableItem,
|
||||
};
|
||||
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
};
|
||||
18
web/lib/opal/src/core/hoverable/styles.css
Normal file
18
web/lib/opal/src/core/hoverable/styles.css
Normal file
@@ -0,0 +1,18 @@
|
||||
/* Hoverable — item transitions */
|
||||
.hoverable-item {
|
||||
transition: opacity 150ms ease-in-out;
|
||||
}
|
||||
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Group mode — Root controls visibility via React context */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Local mode — item handles its own :hover */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
@@ -1,9 +1,16 @@
|
||||
/* Hoverable */
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
} from "@opal/core/hoverable/components";
|
||||
|
||||
/* Interactive */
|
||||
export {
|
||||
Interactive,
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core/interactive/components";
|
||||
|
||||
@@ -3,7 +3,12 @@ import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import { sizeVariants, type SizeVariant } from "@opal/shared";
|
||||
import {
|
||||
sizeVariants,
|
||||
type SizeVariant,
|
||||
widthVariants,
|
||||
type WidthVariant,
|
||||
} from "@opal/shared";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -39,18 +44,6 @@ type InteractiveBaseVariantProps =
|
||||
selected?: never;
|
||||
};
|
||||
|
||||
/**
|
||||
* Width presets for `Interactive.Container`.
|
||||
*
|
||||
* - `"auto"` — Shrink-wraps to content width (default)
|
||||
* - `"full"` — Stretches to fill the parent's width (`w-full`)
|
||||
*/
|
||||
type InteractiveContainerWidthVariant = "auto" | "full";
|
||||
const interactiveContainerWidthVariants = {
|
||||
auto: "w-auto",
|
||||
full: "w-full",
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Border-radius presets for `Interactive.Container`.
|
||||
*
|
||||
@@ -345,7 +338,7 @@ interface InteractiveContainerProps
|
||||
*
|
||||
* @default "auto"
|
||||
*/
|
||||
widthVariant?: InteractiveContainerWidthVariant;
|
||||
widthVariant?: WidthVariant;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -413,7 +406,7 @@ function InteractiveContainer({
|
||||
height,
|
||||
minWidth,
|
||||
padding,
|
||||
interactiveContainerWidthVariants[widthVariant],
|
||||
widthVariants[widthVariant],
|
||||
slotClassName
|
||||
),
|
||||
"data-border": border ? ("true" as const) : undefined,
|
||||
@@ -490,6 +483,5 @@ export {
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSelectVariantProps,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user