mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 22:25:47 +00:00
Compare commits
56 Commits
refactor/l
...
worktree-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
992ad3b8d4 | ||
|
|
a6404f8b3e | ||
|
|
efc49c9f6b | ||
|
|
5e447440ea | ||
|
|
78c6ca39b8 | ||
|
|
71a7cf09b3 | ||
|
|
91d30a0156 | ||
|
|
7b30752767 | ||
|
|
4450ecf07c | ||
|
|
0e6b766996 | ||
|
|
12c8cd338b | ||
|
|
ad5688bf65 | ||
|
|
d2deefd1f1 | ||
|
|
18b90d405d | ||
|
|
8394e8837b | ||
|
|
f06df891c4 | ||
|
|
d6d5e72c18 | ||
|
|
449f5d62f9 | ||
|
|
4d256c5666 | ||
|
|
2e53496f46 | ||
|
|
63a206706a | ||
|
|
28427b3e5f | ||
|
|
3cafcd8a5e | ||
|
|
f2c50b7bb5 | ||
|
|
6b28c6bbfc | ||
|
|
226e801665 | ||
|
|
be13aa1310 | ||
|
|
45d38c4906 | ||
|
|
8aab518532 | ||
|
|
da6ce10e86 | ||
|
|
aaf8253520 | ||
|
|
7c7f81b164 | ||
|
|
2d4a3c72e9 | ||
|
|
7c51712018 | ||
|
|
aa5614695d | ||
|
|
8d7255d3c4 | ||
|
|
d403498f48 | ||
|
|
9ef3095c17 | ||
|
|
a39e93a0cb | ||
|
|
46d73cdfee | ||
|
|
1e04ce78e0 | ||
|
|
f9b81c1725 | ||
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 | ||
|
|
5c5a6f943b | ||
|
|
d04128b8b1 | ||
|
|
bbebdf8f78 | ||
|
|
161279a2d5 | ||
|
|
e5ebb45a20 | ||
|
|
320ba9cb1b | ||
|
|
f2e8cb3114 |
@@ -9,7 +9,8 @@ inputs:
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
required: false
|
||||
default: ""
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
@@ -17,6 +18,14 @@ inputs:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
api-version:
|
||||
description: "Optional NIGHTLY_LLM_API_VERSION"
|
||||
required: false
|
||||
default: ""
|
||||
deployment-name:
|
||||
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
@@ -59,6 +68,7 @@ runs:
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
@@ -82,6 +92,8 @@ runs:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
|
||||
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
@@ -91,11 +103,6 @@ runs:
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -110,10 +117,13 @@ runs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
|
||||
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
|
||||
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -114,10 +114,8 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -603,7 +603,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
@@ -89,6 +89,10 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
@@ -3,33 +3,66 @@ name: Reusable Nightly LLM Provider Chat Tests
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_models:
|
||||
description: "Comma-separated models for azure"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
ollama_models:
|
||||
description: "Comma-separated models for ollama_chat"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
openrouter_models:
|
||||
description: "Comma-separated models for openrouter"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_api_base:
|
||||
description: "API base for azure provider"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
@@ -38,29 +71,8 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -90,7 +102,6 @@ jobs:
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -119,7 +130,6 @@ jobs:
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -149,11 +159,75 @@ jobs:
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -167,12 +241,14 @@ jobs:
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
@@ -194,7 +270,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add python tool on default
|
||||
|
||||
Revision ID: 57122d037335
|
||||
Revises: c0c937d5c9e5
|
||||
Create Date: 2026-02-27 10:10:40.124925
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "57122d037335"
|
||||
down_revision = "c0c937d5c9e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up the PythonTool id
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
tool_id = result[0]
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE persona_id = 0 AND tool_id = :tool_id
|
||||
"""
|
||||
),
|
||||
{"tool_id": result[0]},
|
||||
)
|
||||
@@ -1,8 +1,8 @@
|
||||
"""LLMProvider deprecated fields are nullable
|
||||
"""llm provider deprecate fields
|
||||
|
||||
Revision ID: 001984c88745
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-01 22:24:34.171100
|
||||
Revision ID: c0c937d5c9e5
|
||||
Revises: 8ffcc2bcfc11
|
||||
Create Date: 2026-02-25 17:35:46.125102
|
||||
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "001984c88745"
|
||||
down_revision = "7616121f6e97"
|
||||
revision = "c0c937d5c9e5"
|
||||
down_revision = "8ffcc2bcfc11"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -26,6 +26,13 @@ def upgrade() -> None:
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
|
||||
op.drop_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
type_="unique",
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
@@ -34,8 +41,6 @@ def upgrade() -> None:
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
# is_default_provider and default_vision_model are already nullable with no server_default
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
@@ -49,6 +54,13 @@ def downgrade() -> None:
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore unique constraint on is_default_provider
|
||||
op.create_unique_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
["is_default_provider"],
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
@@ -128,12 +129,19 @@ class ScimDAL(DAL):
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
department=f.department,
|
||||
manager=f.manager,
|
||||
given_name=f.given_name,
|
||||
family_name=f.family_name,
|
||||
scim_emails_json=f.scim_emails_json,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
@@ -311,8 +319,14 @@ class ScimDAL(DAL):
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
@@ -320,11 +334,18 @@ class ScimDAL(DAL):
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
@@ -598,8 +597,12 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
# Keep percent-encoding intact so the path matches the encoding
|
||||
# used by the Office365 library's SPResPath.create_relative(),
|
||||
# which compares against urlparse(context.base_url).path.
|
||||
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
|
||||
# the site prefix in the constructed URL.
|
||||
server_relative_url = urlparse(site_url).path
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
@@ -40,6 +41,8 @@ from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -47,6 +50,7 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
@@ -122,7 +126,7 @@ def get_schemas() -> ScimJSONResponse:
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
@@ -261,6 +265,45 @@ def _build_list_response(
|
||||
)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
|
||||
resource: ScimUserResource,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract department and manager from enterprise extension."""
|
||||
ext = resource.enterprise_extension
|
||||
if not ext:
|
||||
return None, None
|
||||
department = ext.department
|
||||
manager = ext.manager.value if ext.manager else None
|
||||
return department, manager
|
||||
|
||||
|
||||
def _mapping_to_fields(
|
||||
mapping: ScimUserMapping | None,
|
||||
) -> ScimMappingFields | None:
|
||||
"""Extract round-trip fields from a SCIM user mapping."""
|
||||
if not mapping:
|
||||
return None
|
||||
return ScimMappingFields(
|
||||
department=mapping.department,
|
||||
manager=mapping.manager,
|
||||
given_name=mapping.given_name,
|
||||
family_name=mapping.family_name,
|
||||
scim_emails_json=mapping.scim_emails_json,
|
||||
)
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
"""Build mapping fields from an incoming SCIM user resource."""
|
||||
department, manager = _extract_enterprise_fields(resource)
|
||||
return ScimMappingFields(
|
||||
department=department,
|
||||
manager=manager,
|
||||
given_name=resource.name.givenName if resource.name else None,
|
||||
family_name=resource.name.familyName if resource.name else None,
|
||||
scim_emails_json=serialize_emails(resource.emails),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -279,6 +322,7 @@ def list_users(
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -297,6 +341,7 @@ def list_users(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
@@ -321,6 +366,7 @@ def get_user(
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
@@ -334,6 +380,7 @@ def get_user(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
@@ -392,14 +439,23 @@ def create_user(
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(user, external_id, scim_username=scim_username),
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
@@ -438,7 +494,13 @@ def replace_user(
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -448,6 +510,7 @@ def replace_user(
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -476,16 +539,18 @@ def patch_user(
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(
|
||||
patched, ent_data = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
@@ -520,8 +585,25 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails)
|
||||
if patched.emails is not None
|
||||
else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
@@ -532,6 +614,7 @@ def patch_user(
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -640,6 +723,7 @@ def list_groups(
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -676,6 +760,7 @@ def get_group(
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
|
||||
@@ -7,6 +7,7 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
@@ -32,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -71,6 +75,36 @@ class ScimUserGroupRef(BaseModel):
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
|
||||
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
|
||||
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
|
||||
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
|
||||
|
||||
department: str | None = None
|
||||
manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScimMappingFields:
|
||||
"""Stored SCIM mapping fields that need to round-trip through the IdP.
|
||||
|
||||
Entra ID sends structured name components, email metadata, and enterprise
|
||||
extension attributes that must be returned verbatim in subsequent GET
|
||||
responses. These fields are persisted on ScimUserMapping and threaded
|
||||
through the DAL, provider, and endpoint layers.
|
||||
"""
|
||||
|
||||
department: str | None = None
|
||||
manager: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -79,6 +113,8 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
@@ -89,6 +125,10 @@ class ScimUserResource(BaseModel):
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
|
||||
@@ -14,10 +14,13 @@ responsible for persisting changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
@@ -26,6 +29,11 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
@@ -86,6 +94,7 @@ class _UserPatchCtx:
|
||||
|
||||
data: dict[str, Any]
|
||||
name_data: dict[str, Any]
|
||||
ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -97,7 +106,7 @@ def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
) -> tuple[ScimUserResource, dict[str, str | None]]:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
@@ -105,8 +114,10 @@ def apply_user_patch(
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
Returns:
|
||||
A tuple of (modified user resource, enterprise extension data dict).
|
||||
The enterprise dict has keys ``"department"`` and ``"manager"``
|
||||
with values set only when a PATCH operation touched them.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
@@ -125,7 +136,7 @@ def apply_user_patch(
|
||||
)
|
||||
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data)
|
||||
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
@@ -209,6 +220,8 @@ def _set_user_field(
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif path.startswith(_ENTERPRISE_URN_LOWER):
|
||||
_set_enterprise_field(path, value, ctx.ent_data)
|
||||
elif not strict:
|
||||
return
|
||||
else:
|
||||
@@ -227,6 +240,54 @@ def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ent_data: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Handle enterprise extension URN paths or value dicts."""
|
||||
# Full URN as key with dict value (path-less PATCH)
|
||||
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
|
||||
if path == _ENTERPRISE_URN_LOWER:
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
if "department" in d:
|
||||
ent_data["department"] = d["department"]
|
||||
if "manager" in d:
|
||||
mgr = d["manager"]
|
||||
if isinstance(mgr, dict):
|
||||
ent_data["manager"] = mgr.get("value")
|
||||
return
|
||||
|
||||
# Dotted URN path, e.g. "urn:...:user:department"
|
||||
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
|
||||
if suffix == "department":
|
||||
ent_data["department"] = str(value) if value is not None else None
|
||||
elif suffix == "manager":
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
ent_data["manager"] = d.get("value")
|
||||
elif isinstance(value, str):
|
||||
ent_data["manager"] = value
|
||||
else:
|
||||
# Unknown enterprise attributes are silently ignored rather than
|
||||
# rejected — IdPs may send attributes we don't model yet.
|
||||
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -2,13 +2,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
@@ -16,6 +25,9 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
@@ -49,12 +61,22 @@ class ScimProvider(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
"""Schema URIs to include in User resource responses.
|
||||
|
||||
Override in subclasses to advertise additional schemas (e.g. the
|
||||
enterprise extension for Entra ID).
|
||||
"""
|
||||
return [SCIM_USER_SCHEMA]
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
@@ -66,27 +88,48 @@ class ScimProvider(ABC):
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
fields: Stored mapping fields that the IdP expects round-tripped.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
# Build enterprise extension when at least one value is present.
|
||||
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
|
||||
enterprise_ext: ScimEnterpriseExtension | None = None
|
||||
schemas = list(self.user_schemas)
|
||||
if f.department is not None or f.manager is not None:
|
||||
manager_ref = (
|
||||
ScimManagerRef(value=f.manager) if f.manager is not None else None
|
||||
)
|
||||
enterprise_ext = ScimEnterpriseExtension(
|
||||
department=f.department,
|
||||
manager=manager_ref,
|
||||
)
|
||||
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
|
||||
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
|
||||
name = self.build_scim_name(user, f)
|
||||
emails = _deserialize_emails(f.scim_emails_json, username)
|
||||
|
||||
resource = ScimUserResource(
|
||||
schemas=schemas,
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
name=name,
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
emails=emails,
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
resource.enterprise_extension = enterprise_ext
|
||||
return resource
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
@@ -106,9 +149,24 @@ class ScimProvider(ABC):
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
def build_scim_name(
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
@@ -119,6 +177,27 @@ class ScimProvider(ABC):
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
|
||||
"""Deserialize stored email entries or build a default work email."""
|
||||
if stored_json:
|
||||
try:
|
||||
entries = json.loads(stored_json)
|
||||
if isinstance(entries, list) and entries:
|
||||
return [ScimEmail(**e) for e in entries]
|
||||
except (json.JSONDecodeError, TypeError, ValidationError):
|
||||
logger.warning(
|
||||
"Corrupt scim_emails_json, falling back to default: %s", stored_json
|
||||
)
|
||||
return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
|
||||
"""Serialize SCIM email entries to JSON for storage."""
|
||||
if not emails:
|
||||
return None
|
||||
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
|
||||
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
|
||||
"""Entra ID (Azure AD) SCIM provider.
|
||||
|
||||
Entra behavioral notes:
|
||||
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
|
||||
— handled by ``ScimPatchOperation.normalize_op`` validator.
|
||||
- Sends the enterprise extension URN as a key in path-less PATCH value
|
||||
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
|
||||
store department/manager values.
|
||||
- Expects the enterprise extension schema in ``schemas`` arrays and
|
||||
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "entra"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return _ENTRA_IGNORED_PATCH_PATHS
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"schemaExtensions": [
|
||||
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
],
|
||||
)
|
||||
|
||||
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_ENTERPRISE_USER_SCHEMA,
|
||||
name="EnterpriseUser",
|
||||
description="Enterprise User extension (RFC 7643 §4.3)",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="department",
|
||||
type="string",
|
||||
description="Department.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="manager",
|
||||
type="complex",
|
||||
description="The user's manager.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Manager user ID.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
|
||||
@@ -123,21 +123,9 @@ def _seed_llms(
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
if len(seeded_providers[0].model_configurations) > 0:
|
||||
default_model = next(
|
||||
(
|
||||
mc
|
||||
for mc in seeded_providers[0].model_configurations
|
||||
if mc.is_visible
|
||||
),
|
||||
seeded_providers[0].model_configurations[0],
|
||||
).name
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id,
|
||||
model_name=default_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -302,12 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -360,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -391,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -429,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -163,13 +162,11 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,8 +221,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -242,29 +241,28 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if context_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -275,7 +273,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -289,7 +287,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -445,17 +443,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if context_files and context_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -466,14 +464,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 5. Add last user message (with context images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -547,11 +545,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -559,7 +557,7 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -570,10 +568,10 @@ def _create_project_files_message(
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from project_files
|
||||
# Use pre-calculated token count from context_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=project_files.total_token_count,
|
||||
token_count=context_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -584,7 +582,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
context_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -627,9 +625,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -647,7 +645,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -788,7 +786,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
for uf in user_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -879,46 +864,54 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -23,7 +23,6 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
|
||||
@@ -872,6 +871,56 @@ class SharepointConnector(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
|
||||
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
|
||||
tenant domain is the first label of the hostname.
|
||||
"""
|
||||
for site_url in self.sites:
|
||||
try:
|
||||
hostname = urlsplit(site_url.strip()).hostname
|
||||
except ValueError:
|
||||
continue
|
||||
if not hostname:
|
||||
continue
|
||||
tenant = hostname.split(".")[0]
|
||||
if tenant:
|
||||
return tenant
|
||||
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
|
||||
return None
|
||||
|
||||
def _resolve_tenant_domain_from_root_site(self) -> str:
|
||||
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
|
||||
Sites.Read.All (a permission the connector already needs)."""
|
||||
root_site = self.graph_client.sites.root.get().execute_query()
|
||||
hostname = root_site.site_collection.hostname
|
||||
if not hostname:
|
||||
raise ConnectorValidationError(
|
||||
"Could not determine tenant domain from root site"
|
||||
)
|
||||
tenant_domain = hostname.split(".")[0]
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from root site hostname '%s'",
|
||||
tenant_domain,
|
||||
hostname,
|
||||
)
|
||||
return tenant_domain
|
||||
|
||||
def _resolve_tenant_domain(self) -> str:
|
||||
"""Determine the tenant domain, preferring site URLs over a Graph API
|
||||
call to avoid needing extra permissions."""
|
||||
from_sites = self._extract_tenant_domain_from_sites()
|
||||
if from_sites:
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from site URLs",
|
||||
from_sites,
|
||||
)
|
||||
return from_sites
|
||||
|
||||
logger.info("No site URLs available; resolving tenant domain from root site")
|
||||
return self._resolve_tenant_domain_from_root_site()
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
@@ -1589,6 +1638,11 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
if not sp_client_id:
|
||||
raise ConnectorValidationError("Client ID is required")
|
||||
if not sp_directory_id:
|
||||
raise ConnectorValidationError("Directory (tenant) ID is required")
|
||||
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
@@ -1641,21 +1695,7 @@ class SharepointConnector(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
raise ConnectorValidationError("No organization found")
|
||||
|
||||
tenant_info: Organization = org[
|
||||
0
|
||||
] # Access first item directly from collection
|
||||
if not tenant_info.verified_domains:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
|
||||
sp_tenant_domain = tenant_info.verified_domains[0].name
|
||||
if not sp_tenant_domain:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
# remove the .onmicrosoft.com part
|
||||
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -40,6 +40,7 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
@@ -118,6 +119,7 @@ def _build_index_filters(
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -265,6 +267,8 @@ def search_pipeline(
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
@@ -299,6 +303,7 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
|
||||
@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import ssl_context
|
||||
from onyx.db.engine.sql_engine import ASYNC_DB_API
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
connect_args["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
cparams["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any
|
||||
@@ -48,11 +49,9 @@ def provide_iam_token(
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
@@ -213,12 +213,8 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = (
|
||||
fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if llm_provider_upsert_request.id
|
||||
else None
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
@@ -242,6 +238,11 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -250,10 +251,6 @@ def upsert_llm_provider(
|
||||
# If its not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
models_to_exist = {
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of existing model configurations by name (single iteration)
|
||||
existing_by_name = {
|
||||
mc.name: mc for mc in existing_llm_provider.model_configurations
|
||||
@@ -309,6 +306,15 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -482,22 +488,6 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -614,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
model_name,
|
||||
provider.default_model_name, # type: ignore[arg-type]
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -806,6 +805,12 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
@@ -861,6 +866,7 @@ def insert_new_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
.returning(ModelConfiguration.id)
|
||||
@@ -895,6 +901,7 @@ def update_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.where(ModelConfiguration.id == model_configuration_id)
|
||||
.returning(ModelConfiguration)
|
||||
|
||||
@@ -2823,8 +2823,17 @@ class LLMProvider(Base):
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Deprecated: use LLMModelFlow with CHAT flow type instead
|
||||
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
# Auto mode: models, visibility, and defaults are managed by GitHub config
|
||||
@@ -2874,6 +2883,9 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
|
||||
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
|
||||
|
||||
@@ -256,9 +256,6 @@ def create_update_persona(
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
@@ -335,6 +332,7 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -344,9 +342,7 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -360,6 +356,15 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -965,6 +970,8 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
@@ -1161,9 +1168,6 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
|
||||
@@ -57,12 +58,19 @@ def fetch_user_project_ids_for_user_files(
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch user project ids for specified user files"""
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [project.id for project in user_file.projects]
|
||||
for user_file in results
|
||||
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
|
||||
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
|
||||
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
|
||||
)
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
user_file_id_to_project_ids: dict[str, list[int]] = {
|
||||
user_file_id: [] for user_file_id in user_file_ids
|
||||
}
|
||||
for user_file_id, project_id in rows:
|
||||
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
|
||||
|
||||
return user_file_id_to_project_ids
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_final_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -49,8 +50,11 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -118,8 +122,11 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -107,9 +113,8 @@ class OpenSearchClient:
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -125,6 +130,142 @@ class OpenSearchClient:
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -192,6 +333,38 @@ class OpenSearchClient:
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -610,43 +783,6 @@ class OpenSearchClient:
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -807,48 +943,6 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings.
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
return self._client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._os_client.update_document(
|
||||
self._client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._os_client.search(
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
@@ -144,6 +145,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -202,6 +204,7 @@ class DocumentQuery:
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -267,6 +270,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -334,6 +338,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -496,6 +501,7 @@ class DocumentQuery:
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -530,6 +536,8 @@ class DocumentQuery:
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -627,6 +635,9 @@ class DocumentQuery:
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
@@ -780,6 +791,9 @@ class DocumentQuery:
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -149,6 +150,18 @@ def build_vespa_filters(
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
) -> str:
|
||||
if persona_id is None:
|
||||
return ""
|
||||
try:
|
||||
pid = int(persona_id)
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -192,6 +205,9 @@ def build_vespa_filters(
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,10 +1,59 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
@@ -29,15 +78,21 @@ def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int |
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
|
||||
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
|
||||
if "[[" not in message:
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _CITATION_LINK_PATTERN.search(message, cursor):
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
@@ -57,18 +112,38 @@ def _normalize_citation_link_destinations(message: str) -> str:
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
normalized_message = _normalize_citation_link_destinations(message)
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -77,7 +152,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n"
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -96,7 +171,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -118,7 +193,30 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code}\n```\n"
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -405,6 +405,7 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -415,14 +416,22 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -97,6 +97,7 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -237,12 +233,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.id:
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -272,7 +268,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.model,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -307,7 +303,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
) -> list[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -332,25 +328,7 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -363,29 +341,21 @@ def put_llm_provider(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderView:
|
||||
# NOTE: Name updating functionality currently not supported. There are many places that still
|
||||
# rely on immutable names, so this will be a larger change
|
||||
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} already exists",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} does not exist",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -423,6 +393,22 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -452,8 +438,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -483,29 +469,28 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
def set_provider_as_default(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/default-vision")
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
vision_model=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@@ -531,7 +516,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -560,18 +545,7 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return vision_provider_response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -581,7 +555,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -618,25 +592,7 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return accessible_providers
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -679,7 +635,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -726,63 +682,7 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
# Get the default model and vision model for the persona
|
||||
# NOTE: This should be ported over to use id as it is blocking on name mutability
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_text_model.llm_provider.id,
|
||||
model_name=default_text_model.name,
|
||||
)
|
||||
if default_text_model
|
||||
else None
|
||||
)
|
||||
default_vision: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_vision_model.llm_provider.id,
|
||||
model_name=default_vision_model.name,
|
||||
)
|
||||
if default_vision_model
|
||||
else None
|
||||
)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider:
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -23,8 +21,6 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
@@ -56,18 +52,19 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -83,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,12 +99,24 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -118,17 +130,18 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -144,6 +157,8 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -165,6 +180,16 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -177,6 +202,10 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -203,8 +232,7 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=LLMModelFlowType.VISION
|
||||
in model_configuration_model.llm_model_flow_types,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
display_name=model_configuration_model.display_name,
|
||||
)
|
||||
|
||||
@@ -397,27 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> "LLMProviderResponse[T]":
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
27
backend/onyx/server/metrics/per_tenant.py
Normal file
27
backend/onyx/server/metrics/per_tenant.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Per-tenant request counter metric.
|
||||
|
||||
Increments a counter on every request, labelled by tenant, so Grafana can
|
||||
answer "which tenant is generating the most traffic?"
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_requests_by_tenant = Counter(
|
||||
"onyx_api_requests_by_tenant_total",
|
||||
"Total API requests by tenant",
|
||||
["tenant_id", "method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def per_tenant_request_callback(info: Info) -> None:
|
||||
"""Increment per-tenant request counter for every request."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
_requests_by_tenant.labels(
|
||||
tenant_id=tenant_id,
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
@@ -32,6 +32,7 @@ from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -72,7 +73,7 @@ _checkout_timeout_total = Counter(
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine"],
|
||||
["handler", "engine", "tenant_id"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
@@ -163,10 +164,14 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_tenant_id"] = tenant_id
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
@@ -174,9 +179,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
@@ -199,9 +207,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
|
||||
@@ -11,9 +11,11 @@ SQLAlchemy connection pool metrics are registered separately via
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -59,6 +61,15 @@ def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
# Explicitly create the default metrics (http_requests_total,
|
||||
# http_request_duration_seconds, etc.) and add them first. The library
|
||||
# skips creating defaults when ANY custom instrumentations are registered
|
||||
# via .add(), so we must include them ourselves.
|
||||
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
|
||||
if default_callback:
|
||||
instrumentator.add(default_callback)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
instrumentator.add(per_tenant_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -32,6 +33,9 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -245,11 +249,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
if (
|
||||
GEN_AI_API_KEY
|
||||
and fetch_default_llm_model(db_session) is None
|
||||
and not INTEGRATION_TESTS_MODE
|
||||
):
|
||||
if GEN_AI_API_KEY and fetch_default_llm_model(db_session) is None:
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
@@ -261,6 +261,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -272,9 +273,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
@@ -316,7 +315,14 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ def generate_intermediate_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ def run_research_agent_call(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=msg_history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ logger = setup_logger()
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,6 +181,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -427,6 +429,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -247,6 +247,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,6 +263,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -456,6 +459,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
|
||||
@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# fastapi-users
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.6.2
|
||||
pypdf==6.7.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
@@ -1216,7 +1217,7 @@ websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
# google-genai
|
||||
werkzeug==3.1.5
|
||||
werkzeug==3.1.6
|
||||
# via sendgrid
|
||||
wrapt==1.17.3
|
||||
# via
|
||||
|
||||
@@ -125,7 +125,7 @@ executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.1
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -90,7 +90,7 @@ docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -108,7 +108,7 @@ durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# sentry-sdk
|
||||
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
@@ -521,3 +522,46 @@ def test_sharepoint_connector_hierarchy_nodes(
|
||||
f"Document {doc.semantic_identifier} should have "
|
||||
"parent_hierarchy_raw_node_id set"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_cert_credentials() -> dict[str, str]:
|
||||
return {
|
||||
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
|
||||
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
|
||||
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
|
||||
"sp_certificate_password": os.environ[
|
||||
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
|
||||
],
|
||||
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
|
||||
}
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_site_urls(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain from site URLs
|
||||
without calling the /organization endpoint."""
|
||||
site_url = os.environ["SHAREPOINT_SITE"]
|
||||
connector = SharepointConnector(sites=[site_url])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
# The tenant domain should match the first label of the site URL hostname
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
|
||||
assert connector.sp_tenant_domain == expected
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_root_site(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain via the root
|
||||
site endpoint when no site URLs are configured."""
|
||||
connector = SharepointConnector(sites=[])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -26,6 +26,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
@@ -43,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -52,6 +53,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
External dependency unit tests for persona file sync.
|
||||
|
||||
Validates that:
|
||||
|
||||
1. The check_for_user_file_project_sync beat task picks up UserFiles with
|
||||
needs_persona_sync=True (not just needs_project_sync).
|
||||
|
||||
2. The process_single_user_file_project_sync worker task reads persona
|
||||
associations from the DB, passes persona_ids to the document index via
|
||||
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
|
||||
|
||||
3. upsert_persona correctly marks affected UserFiles with
|
||||
needs_persona_sync=True when file associations change.
|
||||
|
||||
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
|
||||
since we only need to verify the arguments passed to update_single.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_completed_user_file(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
needs_persona_sync: bool = False,
|
||||
needs_project_sync: bool = False,
|
||||
) -> UserFile:
|
||||
"""Insert a UserFile in COMPLETED status."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
needs_project_sync=needs_project_sync,
|
||||
chunk_count=5,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_test_persona(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user_files: list[UserFile] | None = None,
|
||||
) -> Persona:
|
||||
"""Create a minimal Persona via direct model insert."""
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_files=user_files or [],
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _link_file_to_persona(
|
||||
db_session: Session, persona: Persona, user_file: UserFile
|
||||
) -> None:
|
||||
"""Create the join table row between a persona and a user file."""
|
||||
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
_PATCH_QUEUE_DEPTH = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
".get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on a bound Celery task."""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: check_for_user_file_project_sync picks up persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSweepIncludesPersonaSync:
|
||||
"""The beat task must pick up files needing persona sync, not just project sync."""
|
||||
|
||||
def test_persona_sync_flag_enqueues_task(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
|
||||
user = create_test_user(db_session, "persona_sweep")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) in enqueued_ids
|
||||
|
||||
def test_neither_flag_does_not_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with both flags False is not enqueued."""
|
||||
user = create_test_user(db_session, "no_sync")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) not in enqueued_ids
|
||||
|
||||
def test_both_flags_enqueues_once(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with BOTH flags True is enqueued exactly once."""
|
||||
user = create_test_user(db_session, "both_flags")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
matching_calls = [
|
||||
call
|
||||
for call in mock_app.send_task.call_args_list
|
||||
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
|
||||
]
|
||||
assert len(matching_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: process_single_user_file_project_sync passes persona_ids to index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_GET_SETTINGS = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
|
||||
)
|
||||
_PATCH_GET_INDICES = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
|
||||
)
|
||||
_PATCH_HTTPX_INIT = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
|
||||
)
|
||||
_PATCH_DISABLE_VDB = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
|
||||
)
|
||||
|
||||
|
||||
class TestSyncTaskWritesPersonaIds:
|
||||
"""The sync task reads persona associations and sends them to the index."""
|
||||
|
||||
def test_passes_persona_ids_to_update_single(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After linking a file to a persona, sync sends the persona ID."""
|
||||
user = create_test_user(db_session, "sync_persona")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
mock_doc_index.update_single.assert_called_once()
|
||||
call_args = mock_doc_index.update_single.call_args
|
||||
user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert call_args.args[0] == str(uf.id)
|
||||
|
||||
def test_clears_persona_sync_flag(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a successful sync the needs_persona_sync flag is cleared."""
|
||||
user = create_test_user(db_session, "sync_clear")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with patch(_PATCH_DISABLE_VDB, True):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
def test_passes_both_project_and_persona_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to both a project and a persona gets both IDs."""
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserProject
|
||||
|
||||
user = create_test_user(db_session, "sync_both")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
project = UserProject(user_id=user.id, name="test-project", instructions="")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
|
||||
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert user_fields.user_projects is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert project.id in user_fields.user_projects
|
||||
|
||||
# Both flags should be cleared
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.needs_project_sync is False
|
||||
|
||||
def test_deleted_persona_excluded_from_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
|
||||
user = create_test_user(db_session, "sync_deleted")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id not in user_fields.personas
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: upsert_persona marks files for persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpsertPersonaMarksSyncFlag:
|
||||
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
|
||||
|
||||
def test_creating_persona_with_files_marks_sync(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "upsert_create")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
|
||||
def test_updating_persona_files_marks_both_old_and_new(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When file associations change, both the removed and added files are flagged."""
|
||||
user = create_test_user(db_session, "upsert_update")
|
||||
uf_old = _create_completed_user_file(db_session, user)
|
||||
uf_new = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf_old.id],
|
||||
)
|
||||
|
||||
# Clear the flag from creation so we can observe the update
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[uf_new.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf_old)
|
||||
db_session.refresh(uf_new)
|
||||
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
|
||||
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
|
||||
|
||||
def test_removing_all_files_marks_old_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Removing all files from a persona flags the previously associated files."""
|
||||
user = create_test_user(db_session, "upsert_remove")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
|
||||
Mocks the LLM tokenizer and file store since they are not relevant here.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_user_file(db_session: Session, user: User) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
chunk_count=1,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _create_project(db_session: Session, user: User) -> UserProject:
|
||||
project = UserProject(
|
||||
user_id=user.id,
|
||||
name=f"project-{uuid4().hex[:8]}",
|
||||
instructions="",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
|
||||
doc = Document(
|
||||
id=str(user_file.id),
|
||||
source=DocumentSource.USER_FILE,
|
||||
semantic_identifier=user_file.name,
|
||||
sections=[TextSection(text="test chunk content", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
source_document=doc,
|
||||
chunk_id=0,
|
||||
blurb="test chunk",
|
||||
content="test chunk content",
|
||||
source_links={0: ""},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.0] * 768,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_persona_gets_persona_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_persona")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_project_gets_project_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_project")
|
||||
uf = _create_user_file(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_both_gets_both_ids(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_both")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_with_no_associations_gets_empty_lists(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_empty")
|
||||
uf = _create_user_file(db_session, user)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_multiple_personas_all_appear(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to multiple personas should have all their IDs."""
|
||||
user = create_test_user(db_session, "adapter_multi")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona_a = _create_persona(db_session, user)
|
||||
persona_b = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
|
||||
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -45,14 +44,15 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> LLMProviderView:
|
||||
) -> None:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
return upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -107,7 +107,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -152,7 +157,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -184,9 +194,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -194,13 +202,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -247,7 +259,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -280,7 +297,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
provider = upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -288,6 +305,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -298,14 +321,18 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -346,7 +373,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model=model_name,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -410,6 +442,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -419,7 +452,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -439,6 +472,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -465,11 +499,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -478,9 +512,6 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -493,7 +524,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -565,6 +596,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -573,7 +605,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -20,7 +20,6 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
@@ -50,6 +49,7 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -91,14 +91,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -125,16 +125,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -159,16 +159,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -192,14 +190,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -225,16 +223,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -263,14 +259,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -301,6 +297,7 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -325,7 +322,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -333,11 +330,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -365,15 +362,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -402,7 +399,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -410,13 +407,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -441,17 +438,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -477,7 +474,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -485,10 +482,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -535,7 +532,12 @@ def test_upload_with_custom_config_then_change(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -544,10 +546,11 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -566,10 +569,14 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -579,9 +586,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -609,9 +616,7 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=name
|
||||
)
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
@@ -637,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -659,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -713,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -735,10 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -136,6 +135,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -163,8 +163,13 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -233,6 +238,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -304,13 +310,14 @@ class TestAutoModeSyncFeature:
|
||||
|
||||
try:
|
||||
# Step 1: Upload provider WITHOUT auto mode, with initial models
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -337,12 +344,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -353,8 +360,8 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=provider_name
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
@@ -381,6 +388,9 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -422,12 +432,8 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -529,6 +535,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -542,7 +549,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -556,6 +563,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -576,7 +584,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
|
||||
@@ -64,6 +64,7 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -144,7 +144,8 @@ def use_mock_search_pipeline(
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
|
||||
@@ -38,6 +38,7 @@ def _get_search_filters(
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""External dependency unit tests for OpenSearchClient.
|
||||
"""External dependency unit tests for OpenSearchIndexClient.
|
||||
|
||||
These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for testing with automatic cleanup."""
|
||||
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
client = OpenSearchClient(index_name=test_index_name)
|
||||
client = OpenSearchIndexClient(index_name=test_index_name)
|
||||
|
||||
yield client # Test runs here.
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_client(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
|
||||
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
|
||||
"""Creates a search pipeline for testing with automatic cleanup."""
|
||||
test_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
|
||||
|
||||
|
||||
class TestOpenSearchClient:
|
||||
"""Tests for OpenSearchClient."""
|
||||
"""Tests for OpenSearchIndexClient."""
|
||||
|
||||
def test_create_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index with a real schema."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
|
||||
# Verify index exists.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting an existing index returns True."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
|
||||
assert result is True
|
||||
assert test_client.validate_index(expected_mappings=mappings) is False
|
||||
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting a nonexistent index returns False."""
|
||||
# Under test.
|
||||
# Don't create index, just try to delete.
|
||||
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_index_exists(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests checking if an index exists."""
|
||||
# Precondition.
|
||||
# Index should not exist before creation.
|
||||
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
|
||||
# Index should exist after creation.
|
||||
assert test_client.index_exists() is True
|
||||
|
||||
def test_validate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests validating an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
|
||||
# Should return True after creation.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests put_mapping with same schema is idempotent."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Applying the same mappings again should succeed.
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Index should still be valid.
|
||||
assert test_client.validate_index(expected_mappings=mappings)
|
||||
|
||||
def test_put_mapping_adds_new_field(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping successfully adds new fields to existing index."""
|
||||
# Precondition.
|
||||
# Create index with minimal schema (just required fields).
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Add a new field using put_mapping.
|
||||
updated_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
# New field
|
||||
"new_test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
# Should not raise.
|
||||
test_client.put_mapping(updated_mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Validate the new schema includes the new field.
|
||||
assert test_client.validate_index(expected_mappings=updated_mappings)
|
||||
|
||||
def test_put_mapping_fails_on_type_change(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping fails when trying to change existing field type."""
|
||||
# Precondition.
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Try to change test_field type from keyword to text.
|
||||
conflicting_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "text"}, # Changed from keyword to text
|
||||
},
|
||||
}
|
||||
# Should raise because field type cannot be changed.
|
||||
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
|
||||
test_client.put_mapping(conflicting_mappings)
|
||||
|
||||
def test_put_mapping_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping on non-existent index raises an error."""
|
||||
# Precondition.
|
||||
# Index does not exist yet.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index twice raises an error."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchClient) -> None:
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests creating and deleting a search pipeline."""
|
||||
# Under test and postcondition.
|
||||
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_bulk_index_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests bulk indexing documents."""
|
||||
# Precondition.
|
||||
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
|
||||
test_client.index_document(document=doc, tenant_state=tenant_state)
|
||||
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
|
||||
assert len(keep_ids) == 1
|
||||
|
||||
def test_update_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a document's properties."""
|
||||
# Precondition.
|
||||
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
|
||||
assert updated_doc.public == doc.public
|
||||
|
||||
def test_update_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_search_empty_index(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
|
||||
previous_score = current_score
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_random_search(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests the random search query works."""
|
||||
# Precondition.
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> None:
|
||||
opensearch_client.refresh_index()
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
@@ -283,10 +284,10 @@ def vespa_document_index(
|
||||
def opensearch_client(
|
||||
db_session: Session,
|
||||
full_deployment_setup: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for the test tenant."""
|
||||
active = get_active_search_settings(db_session)
|
||||
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
|
||||
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
def test_documents(
|
||||
db_session: Session,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
patch_get_vespa_chunks_page_size: int,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""
|
||||
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
|
||||
@@ -434,6 +434,7 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -447,7 +448,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.oauth_config import create_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
|
||||
# Should use & instead of ? since URL already has query params
|
||||
assert "foo=bar&" in url or "?foo=bar" in url
|
||||
assert "client_id=custom_client_id" in url
|
||||
|
||||
|
||||
class TestUnwrapSensitiveStr:
|
||||
"""Tests for _unwrap_sensitive_str static method"""
|
||||
|
||||
def test_unwrap_sensitive_str(self) -> None:
|
||||
"""Test that both SensitiveValue and plain str inputs are handled"""
|
||||
# SensitiveValue input
|
||||
sensitive = SensitiveValue[str](
|
||||
encrypted_bytes=b"test_client_id",
|
||||
decrypt_fn=lambda b: b.decode(),
|
||||
)
|
||||
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
|
||||
|
||||
# Plain str input
|
||||
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
|
||||
|
||||
@@ -76,9 +76,12 @@ class ChatSessionManager:
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
project_id: int | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
|
||||
@@ -4,12 +4,10 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -34,6 +32,7 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -66,6 +65,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -75,20 +75,9 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -124,12 +113,7 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -142,25 +126,11 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and default_model.model_name in model_names
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return DefaultModel(**response.json())
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ScimTokenManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": name},
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
raw_token=data["raw_token"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active(
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -42,6 +42,18 @@ class DATestPAT(BaseModel):
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestScimToken(BaseModel):
|
||||
"""SCIM bearer token model for testing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
raw_token: str | None = None # Only present on initial creation
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
@@ -116,6 +128,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
|
||||
@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json()["providers"]:
|
||||
for provider in response.json():
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,8 +9,6 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
@@ -22,8 +20,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
@@ -45,32 +41,24 @@ def _create_llm_provider(
|
||||
is_public: bool,
|
||||
is_default: bool,
|
||||
) -> LLMProviderModel:
|
||||
_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
provider = LLMProviderModel(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
)
|
||||
|
||||
if is_default:
|
||||
update_default_provider(_provider.id, default_model_name, db_session)
|
||||
|
||||
provider = db_session.get(LLMProviderModel, _provider.id)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {name} not found")
|
||||
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
return provider
|
||||
|
||||
|
||||
@@ -333,19 +321,13 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
persona=persona,
|
||||
user=admin_model,
|
||||
)
|
||||
assert (
|
||||
allowed_llm.config.model_name
|
||||
== restricted_provider.model_configurations[0].name
|
||||
)
|
||||
assert allowed_llm.config.model_name == restricted_provider.default_model_name
|
||||
|
||||
fallback_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=basic_model,
|
||||
)
|
||||
assert (
|
||||
fallback_llm.config.model_name
|
||||
== default_provider.model_configurations[0].name
|
||||
)
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
@@ -364,7 +346,6 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -384,7 +365,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -399,7 +380,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
@@ -415,7 +396,6 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
|
||||
name="default-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should include the restricted provider since basic_user can access the persona
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
|
||||
|
||||
# Should succeed (persona_id=0 refers to default persona, which is public)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should NOT include the restricted provider since basic_user is not in group2
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
|
||||
|
||||
# Should succeed - admins can access any persona
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should include the restricted provider
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should return the public provider
|
||||
assert len(providers) > 0
|
||||
|
||||
@@ -23,6 +23,8 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
|
||||
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
@@ -34,6 +36,8 @@ class NightlyProviderConfig(BaseModel):
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
deployment_name: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
@@ -45,17 +49,29 @@ def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
def _parse_models_env(env_var: str) -> list[str]:
|
||||
raw_value = os.environ.get(env_var, "").strip()
|
||||
if not raw_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(raw_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_json = None
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
return [str(model).strip() for model in parsed_json if str(model).strip()]
|
||||
|
||||
return [part.strip() for part in raw_value.split(",") if part.strip()]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
model_names = _parse_models_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
api_version = os.environ.get(_ENV_API_VERSION) or None
|
||||
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
@@ -74,6 +90,8 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
@@ -95,10 +113,15 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
if config.provider != "ollama_chat" and not (
|
||||
config.api_key or config.custom_config
|
||||
):
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
message=(
|
||||
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
|
||||
f"provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
@@ -109,6 +132,22 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "azure":
|
||||
if not config.api_base:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
if not config.api_version:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -147,6 +186,8 @@ def _create_provider_payload(
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
@@ -154,6 +195,8 @@ def _create_provider_payload(
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"deployment_name": deployment_name,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
@@ -255,6 +298,8 @@ def _create_and_test_provider_for_model(
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
api_version=config.api_version,
|
||||
deployment_name=config.deployment_name,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
@@ -313,10 +358,21 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
failures: list[str] = []
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
try:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
failures.append(
|
||||
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
|
||||
)
|
||||
|
||||
if failures:
|
||||
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
|
||||
|
||||
@@ -72,6 +72,9 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"read_file" in tool_names
|
||||
), "Default assistant should have FileReaderTool attached"
|
||||
assert (
|
||||
"python" in tool_names
|
||||
), "Default assistant should have PythonTool attached"
|
||||
|
||||
# Also verify by display names for clarity
|
||||
assert (
|
||||
@@ -86,8 +89,11 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"File Reader" in tool_display_names
|
||||
), "Default assistant should have File Reader tool"
|
||||
|
||||
# Should have exactly 5 tools
|
||||
assert (
|
||||
len(tool_associations) == 5
|
||||
), f"Default assistant should have exactly 5 tools attached, got {len(tool_associations)}"
|
||||
"Code Interpreter" in tool_display_names
|
||||
), "Default assistant should have Code Interpreter tool"
|
||||
|
||||
# Should have exactly 6 tools
|
||||
assert (
|
||||
len(tool_associations) == 6
|
||||
), f"Default assistant should have exactly 6 tools attached, got {len(tool_associations)}"
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Integration tests for the unified persona file context flow.
|
||||
|
||||
End-to-end tests that verify:
|
||||
1. Files can be uploaded and attached to a persona via API.
|
||||
2. The persona correctly reports its attached files.
|
||||
3. A chat session with a file-bearing persona processes without error.
|
||||
4. Precedence: custom persona files take priority over project files when
|
||||
the chat session is inside a project.
|
||||
|
||||
These tests run against a real Onyx deployment (all services running).
|
||||
File processing is asynchronous, so we poll the file status endpoint
|
||||
until files reach COMPLETED before chatting.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FILE_PROCESSING_POLL_INTERVAL = 2
|
||||
|
||||
|
||||
def _poll_file_statuses(
|
||||
user_file_ids: list[str],
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus = UserFileStatus.COMPLETED,
|
||||
timeout: int = MAX_DELAY,
|
||||
) -> None:
|
||||
"""Block until all files reach the target status or timeout expires."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/statuses",
|
||||
json={"file_ids": user_file_ids},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
statuses = response.json()
|
||||
if all(f["status"] == target_status.value for f in statuses):
|
||||
return
|
||||
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
|
||||
raise TimeoutError(
|
||||
f"Files {user_file_ids} did not reach {target_status.value} "
|
||||
f"within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_with_files_chat_no_error(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Upload files, attach them to a persona, wait for processing,
|
||||
then send a chat message. Verify no error is returned."""
|
||||
|
||||
# Upload files (creates UserFile records)
|
||||
text_file = create_test_text_file(
|
||||
"The secret project codename is NIGHTINGALE. "
|
||||
"It was started in 2024 by the Advanced Research division."
|
||||
)
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("nightingale_brief.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
assert len(file_descriptors) == 1
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"]
|
||||
assert user_file_id is not None
|
||||
|
||||
# Wait for file processing
|
||||
_poll_file_statuses([user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with the file attached
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Nightingale Agent",
|
||||
description="Agent with secret file",
|
||||
system_prompt="You are a helpful assistant with access to uploaded files.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
# Verify persona has the file
|
||||
persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
|
||||
assert len(persona_snapshots) == 1
|
||||
assert user_file_id in persona_snapshots[0].user_file_ids
|
||||
|
||||
# Chat with the persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test persona file context",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret project codename?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
assert len(response.full_message) > 0, "Response should not be empty"
|
||||
|
||||
|
||||
def test_persona_without_files_still_works(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A persona with no attached files should still chat normally."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Blank Agent",
|
||||
description="No files attached",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test blank persona",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="Hello, how are you?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
def test_persona_files_override_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a custom persona (with its own files) is used inside a project,
|
||||
the persona's files take precedence — the project's files are invisible.
|
||||
|
||||
We verify this by putting different content in project vs persona files
|
||||
and checking which content the model responds with."""
|
||||
|
||||
# Upload persona file
|
||||
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
|
||||
persona_fds, err1 = FileManager.upload_files(
|
||||
files=[("persona_secret.txt", persona_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not err1
|
||||
persona_user_file_id = persona_fds[0]["user_file_id"]
|
||||
assert persona_user_file_id is not None
|
||||
# Create a project and upload project files
|
||||
project = ProjectManager.create(
|
||||
name="Precedence Test Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_secret.txt", b"The project's secret word is FLAMINGO."),
|
||||
]
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for both persona and project file processing
|
||||
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with persona file
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Override Agent",
|
||||
description="Persona with its own files",
|
||||
system_prompt="You are a helpful assistant. Answer using the files.",
|
||||
user_file_ids=[persona_user_file_id],
|
||||
)
|
||||
|
||||
# Create chat session inside the project but using the custom persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret word?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
# The persona's file should be what the model sees, not the project's
|
||||
message_lower = response.full_message.lower()
|
||||
assert "albatross" in message_lower, (
|
||||
"Response should reference the persona file's secret word (ALBATROSS), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_default_persona_in_project_uses_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the default persona (id=0) is used inside a project,
|
||||
the project's files should be used for context."""
|
||||
project = ProjectManager.create(
|
||||
name="Default Persona Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_info.txt", b"The project mascot is a PANGOLIN."),
|
||||
]
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(upload_result.user_files) == 1
|
||||
|
||||
# Wait for project file processing
|
||||
project_file_id = str(upload_result.user_files[0].id)
|
||||
_poll_file_statuses([project_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create chat session inside project using default persona (id=0)
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project mascot?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert "pangolin" in response.full_message.lower(), (
|
||||
"Response should reference the project file content (PANGOLIN), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_custom_persona_no_files_in_project_ignores_project(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A custom persona with NO files, used inside a project with files,
|
||||
should NOT see the project's files. The project is purely organizational.
|
||||
|
||||
We verify by asking about content only in the project file and checking
|
||||
the model does NOT reference it."""
|
||||
|
||||
project = ProjectManager.create(
|
||||
name="Ignored Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for project file processing
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Custom persona with no files
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="No Files Agent",
|
||||
description="No files, project is irrelevant",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant. If you do not have information "
|
||||
"to answer a question, say 'I do not have that information.'"
|
||||
),
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project secret?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
assert "capybara" not in response.full_message.lower(), (
|
||||
"Response should NOT reference the project file content (CAPYBARA) "
|
||||
"because the custom persona has no files and should not inherit "
|
||||
f"project files, but got: {response.full_message}"
|
||||
)
|
||||
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Integration tests for SCIM token management.
|
||||
|
||||
Covers the admin token API and SCIM bearer-token authentication:
|
||||
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
|
||||
2. Token rotation: creating a new token revokes previous tokens
|
||||
3. Revoked tokens are rejected by SCIM endpoints
|
||||
4. Non-admin users cannot manage SCIM tokens
|
||||
5. SCIM requests without a token are rejected
|
||||
6. Service discovery endpoints work without authentication
|
||||
7. last_used_at is updated after a SCIM request
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
"""Create token → retrieve metadata → use for SCIM request."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Test SCIM Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert token.raw_token is not None
|
||||
assert token.raw_token.startswith("onyx_scim_")
|
||||
assert token.is_active is True
|
||||
assert "****" in token.token_display
|
||||
|
||||
# GET returns the same metadata but raw_token is None because the
|
||||
# server only reveals the raw token once at creation time (it stores
|
||||
# only the SHA-256 hash).
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
assert body["totalResults"] >= 0
|
||||
|
||||
|
||||
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
"""Creating a new token automatically revokes the previous one."""
|
||||
first = ScimTokenManager.create(
|
||||
name="First Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
second = ScimTokenManager.create(
|
||||
name="Second Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert second.raw_token is not None
|
||||
|
||||
# Active token should now be the second one
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to create a SCIM token."""
|
||||
basic_user = UserManager.create(name="scim_basic_user")
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": "Should Fail"},
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_non_admin_cannot_get_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to retrieve SCIM token metadata."""
|
||||
basic_user = UserManager.create(name="scim_basic_user2")
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
|
||||
"""GET active token returns 404 when no token exists."""
|
||||
# new_admin_user depends on the reset fixture, ensuring a clean DB
|
||||
# with no active SCIM tokens.
|
||||
active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
|
||||
assert active is None
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=new_admin_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_service_discovery_no_auth_required(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
def test_last_used_at_updated_after_scim_request(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""last_used_at timestamp is updated after using the token."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Last Used Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert token.raw_token is not None
|
||||
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active is not None
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active_after is not None
|
||||
assert active_after.last_used_at is not None
|
||||
@@ -3,6 +3,8 @@ from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_enumerate_ad_groups_paginated,
|
||||
)
|
||||
@@ -15,6 +17,9 @@ from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
AD_GROUP_ENUMERATION_THRESHOLD,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_external_access_from_sharepoint,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
@@ -266,3 +271,65 @@ def test_enumerate_all_without_token_skips(
|
||||
|
||||
assert results == []
|
||||
mock_enum.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_access_from_sharepoint – site page URL handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"site_base_url, web_url, expected_relative_url",
|
||||
[
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite",
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
"/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/NormalSite",
|
||||
"https://tenant.sharepoint.com/sites/NormalSite/SitePages/Page.aspx",
|
||||
"/sites/NormalSite/SitePages/Page.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces",
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
"/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
),
|
||||
],
|
||||
ids=["apostrophe-encoded", "no-special-chars", "space-encoded"],
|
||||
)
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_site_page_url_not_duplicated(
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
site_base_url: str,
|
||||
web_url: str,
|
||||
expected_relative_url: str,
|
||||
) -> None:
|
||||
"""Regression: the server-relative URL passed to
|
||||
get_file_by_server_relative_url must preserve percent-encoding so the
|
||||
Office365 library's SPResPath.create_relative() recognises the site prefix
|
||||
and doesn't duplicate it."""
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.base_url = site_base_url
|
||||
|
||||
site_page = {"webUrl": web_url}
|
||||
|
||||
get_external_access_from_sharepoint(
|
||||
client_context=ctx,
|
||||
graph_client=MagicMock(),
|
||||
drive_name=None,
|
||||
drive_item=None,
|
||||
site_page=site_page,
|
||||
)
|
||||
|
||||
ctx.web.get_file_by_server_relative_url.assert_called_once_with(
|
||||
expected_relative_url
|
||||
)
|
||||
|
||||
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tests for the unified context file extraction logic (Phase 5).
|
||||
|
||||
Covers:
|
||||
- resolve_context_user_files: precedence rule (custom persona supersedes project)
|
||||
- extract_context_files: all-or-nothing context window fit check
|
||||
- Search filter / search_usage determination in the caller
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.process_message import determine_search_params
|
||||
from onyx.chat.process_message import extract_context_files
|
||||
from onyx.chat.process_message import resolve_context_user_files
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
token_count: int = 100,
|
||||
name: str = "file.txt",
|
||||
file_id: str | None = None,
|
||||
) -> UserFile:
|
||||
file_uuid = UUID(file_id) if file_id else uuid4()
|
||||
return UserFile(
|
||||
id=file_uuid,
|
||||
file_id=str(file_uuid),
|
||||
name=name,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(
|
||||
persona_id: int,
|
||||
user_files: list | None = None,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.user_files = user_files or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_in_memory_file(
|
||||
file_id: str,
|
||||
content: str = "hello world",
|
||||
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
|
||||
filename: str = "file.txt",
|
||||
) -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=content.encode("utf-8"),
|
||||
file_type=file_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# resolve_context_user_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveContextUserFiles:
|
||||
"""Precedence rule: custom persona fully supersedes project."""
|
||||
|
||||
def test_custom_persona_with_files_returns_persona_files(self) -> None:
|
||||
persona_files = [_make_user_file(), _make_user_file()]
|
||||
persona = _make_persona(persona_id=42, user_files=persona_files)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == persona_files
|
||||
|
||||
def test_custom_persona_without_files_returns_empty(self) -> None:
|
||||
"""Custom persona with no files should NOT fall through to project."""
|
||||
persona = _make_persona(persona_id=42, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_custom_persona_none_files_returns_empty(self) -> None:
|
||||
"""Custom persona with user_files=None should NOT fall through."""
|
||||
persona = _make_persona(persona_id=42, user_files=None)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_default_persona_in_project_returns_project_files(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
project_files = [_make_user_file(), _make_user_file()]
|
||||
mock_get_files.return_value = project_files
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
user_id = uuid4()
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
assert result == project_files
|
||||
mock_get_files.assert_called_once_with(
|
||||
project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
def test_default_persona_no_project_returns_empty(self) -> None:
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_custom_persona_without_files_ignores_project(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
"""Even with a project_id, custom persona means project is invisible."""
|
||||
persona = _make_persona(persona_id=7, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_get_files.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# extract_context_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractContextFiles:
|
||||
"""All-or-nothing context window fit check."""
|
||||
|
||||
def test_empty_user_files_returns_empty(self) -> None:
|
||||
db_session = MagicMock()
|
||||
result = extract_context_files(
|
||||
user_files=[],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.uncapped_token_count is None
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=100, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=file_id, content="file content")
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["file content"]
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.total_token_count == 100
|
||||
assert len(result.file_metadata) == 1
|
||||
assert result.file_metadata[0].file_id == file_id
|
||||
|
||||
def test_files_overflow_context_not_loaded(self) -> None:
|
||||
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
|
||||
uf = _make_user_file(token_count=7000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.uncapped_token_count == 7000
|
||||
assert result.total_token_count == 0
|
||||
|
||||
def test_overflow_boundary_exact(self) -> None:
|
||||
"""Token count exactly at the 60% boundary should trigger overflow."""
|
||||
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
|
||||
uf = _make_user_file(token_count=6000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
|
||||
"""Token count just under the 60% boundary should load files."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=5999, file_id=file_id)
|
||||
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.file_texts == ["data"]
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
|
||||
"""Multiple small files that individually fit but collectively overflow."""
|
||||
files = [_make_user_file(token_count=2500) for _ in range(3)]
|
||||
# 3 * 2500 = 7500 > 6000 threshold
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=files,
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
|
||||
"""Reserved tokens shrink the available window."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=3000, file_id=file_id)
|
||||
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=5000,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=50, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert len(result.image_files) == 1
|
||||
assert result.image_files[0].file_id == file_id
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that determine_search_params correctly resolves
|
||||
search_project_id, search_persona_id, and search_usage based on
|
||||
the extraction result and the precedence rule.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_context(
|
||||
use_as_search_filter: bool = False,
|
||||
file_texts: list[str] | None = None,
|
||||
uncapped_token_count: int | None = None,
|
||||
) -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts or [],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=uncapped_token_count,
|
||||
)
|
||||
|
||||
def test_custom_persona_files_fit_no_filter(self) -> None:
|
||||
"""Custom persona, files fit → no search filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
"""Custom persona, files overflow → persona_id filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(use_as_search_filter=True),
|
||||
)
|
||||
assert result.search_persona_id == 42
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
"""Custom persona (no files) in project → nothing leaks from project."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
"""Default persona, project files fit → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
"""Default persona, project files overflow → ENABLED + project_id filter."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
use_as_search_filter=True,
|
||||
uncapped_token_count=7000,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id == 99
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
"""Default persona, no project → AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=None,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
"""Default persona in project with no files → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -74,20 +74,20 @@ def create_tool_response(
|
||||
)
|
||||
|
||||
|
||||
def create_project_files(
|
||||
def create_context_files(
|
||||
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Helper to create ExtractedProjectFiles for testing."""
|
||||
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
project_file_metadata = [
|
||||
ProjectFileMetadata(
|
||||
) -> ExtractedContextFiles:
|
||||
"""Helper to create ExtractedContextFiles for testing."""
|
||||
file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
file_metadata = [
|
||||
ContextFileMetadata(
|
||||
file_id=f"file_{i}",
|
||||
filename=f"file_{i}.txt",
|
||||
file_content=f"Project file {i} content",
|
||||
)
|
||||
for i in range(num_files)
|
||||
]
|
||||
project_image_files = [
|
||||
image_files = [
|
||||
ChatLoadedFile(
|
||||
file_id=f"image_{i}",
|
||||
content=b"",
|
||||
@@ -98,13 +98,13 @@ def create_project_files(
|
||||
)
|
||||
for i in range(num_images)
|
||||
]
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=False,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=num_files * tokens_per_file,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=num_files * tokens_per_file,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=num_files * tokens_per_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("How are you?", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
|
||||
assert result[3] == custom_agent # Before last user message
|
||||
assert result[4] == user_msg2
|
||||
|
||||
def test_with_project_files(self) -> None:
|
||||
def test_with_context_files(self) -> None:
|
||||
"""Test that project files are inserted before the last user message."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg1 = create_message("First message", MessageType.USER, 5)
|
||||
user_msg2 = create_message("Second message", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, project_files_message, user2
|
||||
# Should have: system, user1, context_files_message, user2
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
|
||||
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
|
||||
assert len(result) == 6
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Second", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=0, num_images=2)
|
||||
context_files = create_context_files(num_files=0, num_images=2)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
|
||||
)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=0, num_images=1)
|
||||
context_files = create_context_files(num_files=0, num_images=1)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg2,
|
||||
user_msg3,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
|
||||
result = construct_message_history(
|
||||
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
|
||||
tool_response = create_tool_response("tc_1", "tool_response", 20)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last user message and messages after + system
|
||||
# (10 + 20 + 20 + 20 = 70 tokens)
|
||||
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg1,
|
||||
user_msg2,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
|
||||
# keeps [tool_response, assistant_msg1] from history_before_last_user,
|
||||
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=30,
|
||||
)
|
||||
|
||||
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Latest question", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
|
||||
# keeps both assistant_with_tool and tool_response in history_before_last_user.
|
||||
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=45,
|
||||
)
|
||||
|
||||
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Reminder", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history: list[ChatMessageSimple] = []
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, custom_agent, project_files, reminder
|
||||
# Should have: system, custom_agent, context_files, reminder
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == custom_agent
|
||||
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
|
||||
|
||||
simple_chat_history = [assistant_msg, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
with pytest.raises(ValueError, match="No user message found"):
|
||||
construct_message_history(
|
||||
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 50)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=100)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=100)
|
||||
|
||||
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
|
||||
# But only 200 available
|
||||
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=200,
|
||||
)
|
||||
|
||||
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget: 50 tokens
|
||||
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
|
||||
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=50,
|
||||
)
|
||||
|
||||
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=20)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=20)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Expected order:
|
||||
# system, user1, assistant1, user2, assistant2,
|
||||
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
|
||||
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
|
||||
assert len(result) == 11
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
|
||||
assert result[9] == tool_response # After last user
|
||||
assert result[10] == reminder # At the very end
|
||||
|
||||
def test_project_files_json_format(self) -> None:
|
||||
def test_context_files_json_format(self) -> None:
|
||||
"""Test that project files are formatted correctly as JSON."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Hello", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=create_project_files(),
|
||||
context_files=create_context_files(),
|
||||
available_tokens=available_tokens,
|
||||
token_counter=_simple_token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -98,6 +98,11 @@ class TestScimDALUserMappings:
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
"department": None,
|
||||
"manager": None,
|
||||
"given_name": None,
|
||||
"family_name": None,
|
||||
"scim_emails_json": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from onyx.onyxbot.slack.formatting import _normalize_citation_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
@@ -9,7 +12,7 @@ def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
@@ -20,7 +23,7 @@ def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
@@ -31,7 +34,7 @@ def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
@@ -50,3 +53,54 @@ def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> Non
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
983
backend/tests/unit/onyx/server/scim/test_entra.py
Normal file
983
backend/tests/unit/onyx/server/scim/test_entra.py
Normal file
@@ -0,0 +1,983 @@
|
||||
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
|
||||
|
||||
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
|
||||
with enterprise extension schema, group CRUD with excludedAttributes, and
|
||||
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
|
||||
PATCH value dicts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_group
|
||||
from ee.onyx.server.scim.api import get_resource_types
|
||||
from ee.onyx.server.scim.api import get_schemas
|
||||
from ee.onyx.server.scim.api import get_service_provider_config
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_groups
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entra_provider() -> ScimProvider:
|
||||
"""An EntraProvider instance for Entra-specific endpoint tests."""
|
||||
return EntraProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraServiceDiscovery:
|
||||
"""Entra expects enterprise extension in discovery endpoints."""
|
||||
|
||||
def test_service_provider_config_advertises_patch(self) -> None:
|
||||
config = get_service_provider_config()
|
||||
assert config.patch.supported is True
|
||||
|
||||
def test_resource_types_include_enterprise_extension(self) -> None:
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "Resources" in parsed
|
||||
user_type = next(rt for rt in parsed["Resources"] if rt["id"] == "User")
|
||||
extension_schemas = [ext["schema"] for ext in user_type["schemaExtensions"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in extension_schemas
|
||||
|
||||
def test_schemas_include_enterprise_user(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
schema_ids = [s["id"] for s in parsed["Resources"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in schema_ids
|
||||
|
||||
def test_enterprise_schema_has_expected_attributes(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
enterprise = next(
|
||||
s for s in parsed["Resources"] if s["id"] == SCIM_ENTERPRISE_USER_SCHEMA
|
||||
)
|
||||
attr_names = {a["name"] for a in enterprise["attributes"]}
|
||||
assert "department" in attr_names
|
||||
assert "manager" in attr_names
|
||||
|
||||
def test_service_discovery_content_type(self) -> None:
|
||||
"""SCIM responses must use application/scim+json content type."""
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
assert result.media_type == "application/scim+json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraUserLifecycle:
|
||||
"""Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="alice@contoso.com")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
assert SCIM_USER_SCHEMA in resource.schemas
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_with_enterprise_extension(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Enterprise extension department/manager should round-trip on create."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
enterprise_extension=ScimEnterpriseExtension(
|
||||
department="Engineering",
|
||||
manager=ScimManagerRef(value="mgr-uuid-123"),
|
||||
),
|
||||
)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Engineering"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-uuid-123"
|
||||
|
||||
# Verify DAL received the enterprise fields
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
call_kwargs = mock_dal.create_user_mapping.call_args[1]
|
||||
assert call_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
manager="mgr-uuid-123",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
)
|
||||
|
||||
def test_get_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_get_user_returns_enterprise_extension_data(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""GET should return stored enterprise extension data."""
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.department = "Sales"
|
||||
mapping.manager = "mgr-456"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Sales"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-456"
|
||||
|
||||
def test_list_users_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_patch_user_deactivate_with_pascal_case_replace(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Replace", # type: ignore[arg-type]
|
||||
path="active",
|
||||
value=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Mock doesn't propagate the change, so verify via the DAL call
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
|
||||
def test_patch_user_add_external_id_with_pascal_case(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Add", # type: ignore[arg-type]
|
||||
path="externalId",
|
||||
value="entra-ext-999",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify the patched externalId was synced to the DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] == "entra-ext-999"
|
||||
|
||||
def test_patch_user_enterprise_extension_in_value_dict(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends enterprise extension URN as key in path-less PATCH value
|
||||
dicts — enterprise data should be stored, not ignored."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path=None,
|
||||
value=value,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify active=False was applied
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
# Verify enterprise data was passed to DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_patch_user_remove_external_id(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH remove op should clear the target field."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.external_id = "ext-to-remove"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REMOVE,
|
||||
path="externalId",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# externalId should be cleared (None)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] is None
|
||||
|
||||
def test_patch_user_emails_primary_eq_true_value(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with path emails[primary eq true].value should update
|
||||
the primary email entry, not userName."""
|
||||
user = make_db_user(email="old@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="emails[primary eq true].value",
|
||||
value="new@contoso.com",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert resource.userName == "old@contoso.com"
|
||||
# Primary email should be updated
|
||||
primary_emails = [e for e in resource.emails if e.primary]
|
||||
assert len(primary_emails) == 1
|
||||
assert primary_emails[0].value == "new@contoso.com"
|
||||
|
||||
def test_patch_user_enterprise_urn_department_path(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with dotted enterprise URN path should store department."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
value="Marketing",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Marketing",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_replace_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="old@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
resource = make_scim_user(
|
||||
userName="new@contoso.com",
|
||||
name=ScimName(givenName="New", familyName="Name"),
|
||||
)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_replace_user_with_enterprise_extension(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PUT with enterprise extension should store the fields."""
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
enterprise_extension=ScimEnterpriseExtension(
|
||||
department="HR",
|
||||
manager=ScimManagerRef(value="boss-id"),
|
||||
),
|
||||
)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="HR",
|
||||
manager="boss-id",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
)
|
||||
|
||||
def test_delete_user_returns_204(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
def test_double_delete_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""Second DELETE should return 404 — the SCIM mapping is gone."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
# No mapping — user was already deleted from SCIM's perspective
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = None
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
assert result.status_code == 404
|
||||
|
||||
def test_name_formatted_preserved_on_create(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""When name.formatted is provided, it should be used as personal_name."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
name=ScimName(
|
||||
givenName="Alice",
|
||||
familyName="Smith",
|
||||
formatted="Dr. Alice Smith",
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api._check_seat_availability", return_value=None
|
||||
):
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result, status=201)
|
||||
# The User constructor should have received the formatted name
|
||||
mock_dal.add_user.assert_called_once()
|
||||
created_user = mock_dal.add_user.call_args[0][0]
|
||||
assert created_user.personal_name == "Dr. Alice Smith"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraGroupLifecycle:
|
||||
"""Test group CRUD with Entra-specific behaviors."""
|
||||
|
||||
def test_get_group_standard_response(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=10, name="Contoso Engineering")
|
||||
mock_dal.get_group.return_value = group
|
||||
uid = uuid4()
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="10",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert resource.displayName == "Contoso Engineering"
|
||||
assert len(resource.members) == 1
|
||||
|
||||
def test_list_groups_with_excluded_attributes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ?excludedAttributes=members on group list queries."""
|
||||
group = make_db_group(id=10, name="Engineering")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
excludedAttributes="members",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert parsed["totalResults"] == 1
|
||||
resource = parsed["Resources"][0]
|
||||
assert "members" not in resource
|
||||
assert resource["displayName"] == "Engineering"
|
||||
|
||||
def test_get_group_with_excluded_attributes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ?excludedAttributes=members on single group GET."""
|
||||
group = make_db_group(id=10, name="Engineering")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="10",
|
||||
excludedAttributes="members",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert parsed["displayName"] == "Engineering"
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_patch_group_add_members_with_pascal_case(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Add"`` (PascalCase) for group member additions."""
|
||||
group = make_db_group(id=10)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
mock_dal.validate_member_ids.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(
|
||||
id="10",
|
||||
displayName="Engineering",
|
||||
members=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
mock_apply.return_value = (patched, [uid], [])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Add", # type: ignore[arg-type]
|
||||
path="members",
|
||||
value=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="10",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_patch_group_remove_member_with_pascal_case(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Remove"`` (PascalCase) for group member removals."""
|
||||
group = make_db_group(id=10)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(id="10", displayName="Engineering", members=[])
|
||||
mock_apply.return_value = (patched, [], [uid])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Remove", # type: ignore[arg-type]
|
||||
path=f'members[value eq "{uid}"]',
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="10",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# excludedAttributes (RFC 7644 §3.4.2.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExcludedAttributes:
|
||||
"""Test excludedAttributes query parameter on GET endpoints."""
|
||||
|
||||
def test_list_groups_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, None)], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
excludedAttributes="members",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "members" not in resource
|
||||
assert "displayName" in resource
|
||||
|
||||
def test_get_group_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_list_users_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
mock_dal.get_users_groups_batch.return_value = {user.id: [(1, "Engineering")]}
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
excludedAttributes="groups",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "groups" not in resource
|
||||
assert "userName" in resource
|
||||
|
||||
def test_get_user_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_dal.get_user_groups.return_value = [(1, "Engineering")]
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
excludedAttributes="groups",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "groups" not in parsed
|
||||
assert "userName" in parsed
|
||||
|
||||
def test_multiple_excluded_attributes(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members,externalId",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "externalId" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_no_excluded_attributes_returns_full_response(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert len(resource.members) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entra Connection Probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraConnectionProbe:
|
||||
"""Entra sends a probe request during initial SCIM setup."""
|
||||
|
||||
def test_filter_for_nonexistent_user_returns_empty_list(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1"""
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
result = list_users(
|
||||
filter='userName eq "non-existent@contoso.com"',
|
||||
startIndex=1,
|
||||
count=1,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
@@ -13,9 +13,11 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -57,36 +59,36 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_deactivate_user(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("active", False)], user)
|
||||
result, _ = apply_user_patch([_replace_op("active", False)], user)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_activate_user(self) -> None:
|
||||
user = _make_user(active=False)
|
||||
result = apply_user_patch([_replace_op("active", True)], user)
|
||||
result, _ = apply_user_patch([_replace_op("active", True)], user)
|
||||
assert result.active is True
|
||||
|
||||
def test_replace_given_name(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.givenName == "NewFirst"
|
||||
assert result.name.familyName == "User"
|
||||
|
||||
def test_replace_family_name(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.familyName == "NewLast"
|
||||
|
||||
def test_replace_username(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
assert result.userName == "new@example.com"
|
||||
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -100,7 +102,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_multiple_operations(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op("active", False),
|
||||
_replace_op("name.givenName", "Updated"),
|
||||
@@ -113,7 +115,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_case_insensitive_path(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("Active", False)], user)
|
||||
result, _ = apply_user_patch([_replace_op("Active", False)], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_original_not_mutated(self) -> None:
|
||||
@@ -129,7 +131,7 @@ class TestApplyUserPatch:
|
||||
def test_remove_op_clears_field(self) -> None:
|
||||
"""Remove op should clear the target field (not raise)."""
|
||||
user = _make_user(externalId="ext-123")
|
||||
result = apply_user_patch([_remove_op("externalId")], user)
|
||||
result, _ = apply_user_patch([_remove_op("externalId")], user)
|
||||
assert result.externalId is None
|
||||
|
||||
def test_remove_unsupported_path_raises(self) -> None:
|
||||
@@ -141,7 +143,7 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
@@ -151,7 +153,7 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -169,7 +171,7 @@ class TestApplyUserPatch:
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -184,7 +186,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
@@ -195,7 +197,7 @@ class TestApplyUserPatch:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -215,29 +217,65 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
def test_entra_capitalized_replace_op(self) -> None:
|
||||
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
|
||||
result = apply_user_patch([op], user)
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_entra_capitalized_add_op(self) -> None:
|
||||
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
|
||||
result = apply_user_patch([op], user)
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.externalId == "ext-999"
|
||||
|
||||
def test_entra_enterprise_extension_handled(self) -> None:
|
||||
"""Entra sends the enterprise extension URN as a key in path-less
|
||||
PATCH value dicts — enterprise data should be captured in ent_data."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
# Simulate Entra including the enterprise extension URN as extra data
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_ENTRA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_okta_handles_enterprise_extension_urn(self) -> None:
|
||||
"""Enterprise extension URN paths are handled universally, even
|
||||
for Okta — the data is captured in the enterprise data dict."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_emails_primary_eq_true_value(self) -> None:
|
||||
"""emails[primary eq true].value should update the primary email entry."""
|
||||
user = _make_user(
|
||||
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
|
||||
)
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("emails[primary eq true].value", "new@example.com")], user
|
||||
)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
@@ -246,6 +284,34 @@ class TestApplyUserPatch:
|
||||
assert result.emails[0].value == "new@example.com"
|
||||
assert result.emails[0].primary is True
|
||||
|
||||
def test_enterprise_urn_department_path(self) -> None:
|
||||
"""Dotted enterprise URN path should set department in ent_data."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
"Marketing",
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["department"] == "Marketing"
|
||||
|
||||
def test_enterprise_urn_manager_path(self) -> None:
|
||||
"""Dotted enterprise URN path for manager should set manager."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
|
||||
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["manager"] == "boss-id"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
@@ -12,6 +13,8 @@ from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
@@ -167,6 +170,42 @@ class TestOktaProvider:
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestEntraProvider:
|
||||
def test_name(self) -> None:
|
||||
assert EntraProvider().name == "entra"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
paths = EntraProvider().ignored_patch_paths
|
||||
assert paths == _ENTRA_IGNORED_PATCH_PATHS
|
||||
# Enterprise extension URN is now handled (not ignored)
|
||||
assert paths >= COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
def test_build_user_resource_includes_enterprise_schema(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
|
||||
id=str(user.id),
|
||||
externalId="ext-entra-1",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
|
||||
@@ -16,6 +16,7 @@ from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
@@ -418,6 +419,10 @@ class TestReplaceUser:
|
||||
user.id,
|
||||
None,
|
||||
scim_username="test@example.com",
|
||||
fields=ScimMappingFields(
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -106,6 +106,9 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_tenant_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool._connections_held"
|
||||
) as mock_gauge,
|
||||
@@ -114,12 +117,14 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
mock_labels = MagicMock()
|
||||
mock_gauge.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "/api/chat/send-message"
|
||||
mock_tenant_ctx.get.return_value = "tenant_xyz"
|
||||
listeners["checkout"](None, conn_record, None)
|
||||
|
||||
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
|
||||
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
|
||||
assert "_metrics_checkout_time" in conn_record.info
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/chat/send-message", engine="sync"
|
||||
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
@@ -144,6 +149,7 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
conn_record = _make_conn_record()
|
||||
conn_record.info["_metrics_endpoint"] = "/api/search"
|
||||
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
|
||||
|
||||
with (
|
||||
@@ -162,7 +168,9 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/search", engine="sync", tenant_id="tenant_abc"
|
||||
)
|
||||
mock_labels.dec.assert_called_once()
|
||||
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_hist_labels.observe.assert_called_once()
|
||||
@@ -172,11 +180,12 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
# conn_record.info should be cleaned up
|
||||
assert "_metrics_endpoint" not in conn_record.info
|
||||
assert "_metrics_tenant_id" not in conn_record.info
|
||||
assert "_metrics_checkout_time" not in conn_record.info
|
||||
|
||||
|
||||
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
"""Verify checkin gracefully handles missing endpoint info."""
|
||||
"""Verify checkin gracefully handles missing endpoint and tenant info."""
|
||||
engine = MagicMock()
|
||||
engine.pool = MagicMock()
|
||||
listeners: dict[str, Any] = {}
|
||||
@@ -207,7 +216,9 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="unknown", engine="sync", tenant_id="unknown"
|
||||
)
|
||||
|
||||
|
||||
# --- setup_postgres_connection_pool_metrics tests ---
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -81,7 +82,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=["/health", "/metrics", "/openapi.json"],
|
||||
)
|
||||
mock_instance.add.assert_called_once()
|
||||
assert mock_instance.add.call_count == 3
|
||||
mock_instance.instrument.assert_called_once_with(
|
||||
app,
|
||||
latency_lowr_buckets=(
|
||||
@@ -100,6 +101,56 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
mock_instance.expose.assert_called_once_with(app)
|
||||
|
||||
|
||||
def test_per_tenant_callback_increments_with_tenant_id() -> None:
|
||||
"""Verify per-tenant callback reads tenant from contextvar and increments."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "tenant_abc"
|
||||
|
||||
info = _make_info(
|
||||
duration=0.1, method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="tenant_abc",
|
||||
method="POST",
|
||||
handler="/api/chat",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_per_tenant_callback_falls_back_to_unknown() -> None:
|
||||
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = None
|
||||
|
||||
info = _make_info(duration=0.1)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="unknown",
|
||||
method="GET",
|
||||
handler="/api/test",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_inprogress_gauge_increments_during_request() -> None:
|
||||
"""Verify the in-progress gauge goes up while a request is in flight."""
|
||||
registry = CollectorRegistry()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user