Compare commits

..

15 Commits

Author SHA1 Message Date
Dane Urban
80cf389774 . 2026-02-23 16:30:30 -08:00
Danelegend
e775aaacb7 chore: preview modal (#8665) 2026-02-23 16:29:13 -08:00
Justin Tahara
e5b08b3d92 fix(search): Improve Speed (#8430) 2026-02-23 16:29:13 -08:00
Jamison Lahman
7c91304ba2 chore(playwright): warn user if setup takes longer than usual (#8690) 2026-02-23 16:29:13 -08:00
roshan
68a292b500 fix(ui): Clean up NRF settings button styling (#8678)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 16:29:13 -08:00
Justin Tahara
e553b80030 fix(db): Multitenant Schema migration update (#8679) 2026-02-23 16:29:13 -08:00
Justin Tahara
f3949f8e09 chore(ods): Automated Cherry-pick backport (#8642) 2026-02-23 16:29:13 -08:00
Nikolas Garza
c7c064e296 feat(scim): Okta compatibility + provider abstraction (#8568) 2026-02-23 16:29:13 -08:00
Wenxi
68b91a8862 fix: domain rules for signup on cloud (#8671) 2026-02-23 16:29:13 -08:00
roshan
c23e5a196d fix: Handle unauthenticated state gracefully on NRF page (#8491)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 16:29:13 -08:00
Raunak Bhagat
093223c6c4 refactor: migrate Web Search page to SettingsLayouts + Content (#8662) 2026-02-23 16:29:13 -08:00
Danelegend
89517111d4 feat: Add code interpreter server db model (#8669) 2026-02-23 16:29:13 -08:00
Wenxi
883d4b4ceb chore: set trial api usage to 0 and show ui (#8664) 2026-02-23 16:29:13 -08:00
Dane Urban
f3672b6819 CSV rendering 2026-02-22 18:33:39 -08:00
Dane Urban
921f5d9e96 preview modal 2026-02-22 17:42:30 -08:00
303 changed files with 3264 additions and 9366 deletions

View File

@@ -1,73 +0,0 @@
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
docker-no-cache:
description: "Set to 'true' to disable docker build cache"
required: false
default: "false"
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
no-cache: ${{ inputs.docker-no-cache == 'true' }}

View File

@@ -1,75 +0,0 @@
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Build and push integration test image with Docker Bake
shell: bash
env:
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
TAG: nightly-llm-it-${{ inputs.run-id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ inputs.github-sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration

View File

@@ -1,68 +0,0 @@
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max

View File

@@ -1,120 +0,0 @@
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"
inputs:
provider:
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
required: true
models:
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: true
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
api-base:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
default: ""
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
run-id:
description: "GitHub run ID used in image tags"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Create .env file for Docker Compose
shell: bash
env:
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
run: |
cat <<EOF2 > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
- name: Start Docker containers
shell: bash
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server
- name: Run nightly provider integration test
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
env:
MODELS: ${{ inputs.models }}
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
with:
timeout_minutes: 20
max_attempts: 2
retry_wait_seconds: 10
command: |
if [ -z "${MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py

View File

@@ -1,44 +0,0 @@
name: Nightly LLM Provider Chat Tests (OpenAI)
concurrency:
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
openai-provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
provider: openai
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
strict: true
secrets:
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [openai-provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: openai-provider-chat-test
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -11,11 +11,6 @@ permissions:
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
@@ -41,13 +36,9 @@ jobs:
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
# Read the PR body and check whether the helper checkbox is checked.
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
@@ -80,84 +71,9 @@ jobs:
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify

View File

@@ -116,6 +116,7 @@ jobs:
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF

View File

@@ -20,7 +20,6 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -424,7 +423,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -445,7 +443,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -704,7 +701,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -1,206 +0,0 @@
name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
provider:
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
required: true
type: string
models:
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
required: true
type: string
strict:
description: "Pass-through value for NIGHTLY_LLM_STRICT"
required: false
default: true
type: boolean
api_base:
description: "Optional NIGHTLY_LLM_API_BASE override"
required: false
default: ""
type: string
custom_config_json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
required: false
default: ""
type: string
secrets:
provider_api_key:
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
required: true
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
required: true
permissions:
contents: read
env:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
jobs:
validate-inputs:
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Validate required nightly provider inputs
run: |
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
build-backend-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
provider-chat-test:
needs:
[build-backend-image, build-model-server-image, build-integration-image]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
models: ${{ env.NIGHTLY_LLM_MODELS }}
provider-api-key: ${{ secrets.provider_api_key }}
strict: ${{ env.NIGHTLY_LLM_STRICT }}
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose down -v

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
To run them:
@@ -616,9 +616,3 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -1,58 +0,0 @@
"""LLMProvider deprecated fields are nullable
Revision ID: 001984c88745
Revises: 7616121f6e97
Create Date: 2026-02-01 22:24:34.171100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "001984c88745"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
# is_default_provider and default_vision_model are already nullable with no server_default
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)

View File

@@ -1,29 +0,0 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -1,48 +0,0 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -1,33 +0,0 @@
"""add needs_persona_sync to user_file
Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user_file",
sa.Column(
"needs_persona_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("user_file", "needs_persona_sync")

View File

@@ -26,13 +26,14 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
@@ -50,24 +51,12 @@ from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
media_type = "application/scim+json"
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
@@ -97,39 +86,15 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
@scim_router.get("/ResourceTypes")
def get_resource_types() -> ScimJSONResponse:
"""List available SCIM resource types (RFC 7643 §6).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(resources),
"Resources": [
r.model_dump(exclude_none=True, by_alias=True) for r in resources
],
}
)
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> ScimJSONResponse:
"""Return SCIM schema definitions (RFC 7643 §7).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(schemas),
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
}
)
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
@@ -137,45 +102,15 @@ def get_schemas() -> ScimJSONResponse:
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
body = ScimError(status=str(status), detail=detail)
return ScimJSONResponse(
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _parse_excluded_attributes(raw: str | None) -> set[str]:
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
Returns a set of lowercased attribute names to omit from responses.
"""
if not raw:
return set()
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
def _apply_exclusions(
resource: ScimUserResource | ScimGroupResource,
excluded: set[str],
) -> dict:
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
to reduce response payload size. We strip those fields after serialization
so the rest of the pipeline doesn't need to know about them.
"""
data = resource.model_dump(exclude_none=True, by_alias=True)
for attr in excluded:
# Match case-insensitively against the camelCase field names
keys_to_remove = [k for k in data if k.lower() == attr]
for k in keys_to_remove:
del data[k]
return data
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -189,7 +124,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
@@ -209,56 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
# If the client explicitly provides ``formatted``, prefer it — the client
# knows what display string it wants. Otherwise build from components.
if name.formatted:
return name.formatted
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or None
def _scim_resource_response(
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
status_code: int = 200,
) -> ScimJSONResponse:
"""Serialize a SCIM resource as ``application/scim+json``."""
content = resource.model_dump(exclude_none=True, by_alias=True)
return ScimJSONResponse(
status_code=status_code,
content=content,
)
def _build_list_response(
resources: list[ScimUserResource | ScimGroupResource],
total: int,
start_index: int,
count: int,
excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
"""Build a SCIM list response, optionally applying attribute exclusions.
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
the ``excludedAttributes`` query parameter.
"""
if excluded:
envelope = ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
)
data = envelope.model_dump(exclude_none=True)
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
return ScimJSONResponse(content=data)
return _scim_resource_response(
ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
Resources=resources,
)
)
return parts or name.formatted
# ---------------------------------------------------------------------------
@@ -269,13 +158,12 @@ def _build_list_response(
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -301,48 +189,38 @@ def list_users(
for user, mapping in users_with_mappings
]
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
resource = provider.build_user_resource(
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
@@ -350,7 +228,7 @@ def create_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -398,10 +276,7 @@ def create_user(
dal.commit()
return _scim_resource_response(
provider.build_user_resource(user, external_id, scim_username=scim_username),
status_code=201,
)
return provider.build_user_resource(user, external_id, scim_username=scim_username)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -411,13 +286,13 @@ def replace_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
@@ -442,13 +317,11 @@ def replace_user(
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
@@ -459,7 +332,7 @@ def patch_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
@@ -469,7 +342,7 @@ def patch_user(
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
@@ -526,13 +399,11 @@ def patch_user(
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
@@ -541,29 +412,25 @@ def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
A second DELETE returns 404 per RFC 7644 §3.6.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
# If no SCIM mapping exists, the user was already deleted from
# SCIM's perspective — return 404 per RFC 7644 §3.6.
mapping = dal.get_user_mapping_by_user_id(user.id)
if not mapping:
return _scim_error_response(404, f"User {user_id} not found")
dal.deactivate_user(user)
dal.delete_user_mapping(mapping.id)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
@@ -575,7 +442,7 @@ def delete_user(
# ---------------------------------------------------------------------------
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
@@ -630,13 +497,12 @@ def _validate_and_parse_members(
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -656,46 +522,37 @@ def list_groups(
for group, ext_id in groups_with_ext_ids
]
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
resource = provider.build_group_resource(
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
@@ -703,7 +560,7 @@ def create_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -739,10 +596,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return _scim_resource_response(
provider.build_group_resource(db_group, members, external_id),
status_code=201,
)
return provider.build_group_resource(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -752,13 +606,13 @@ def replace_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
@@ -773,9 +627,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, group_resource.externalId)
)
return provider.build_group_resource(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -785,7 +637,7 @@ def patch_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
@@ -794,7 +646,7 @@ def patch_group(
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
@@ -833,9 +685,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, patched.externalId)
)
return provider.build_group_resource(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
@@ -843,13 +693,13 @@ def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result

View File

@@ -13,7 +13,6 @@ from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
@@ -166,19 +165,6 @@ class ScimPatchOperation(BaseModel):
path: str | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_operation(cls, v: object) -> object:
"""Normalize op to lowercase for case-insensitive matching.
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
instead of ``"replace"``. This is safe for all providers since the
enum values are lowercase. If a future provider requires other
pre-processing quirks, move patch deserialization into the provider
subclass instead of adding more special cases here.
"""
return v.lower() if isinstance(v, str) else v
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).

View File

@@ -15,8 +15,6 @@ responsible for persisting changes.
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
@@ -26,50 +24,6 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
# Pattern for email filter paths, e.g.:
# emails[primary eq true].value (Okta)
# emails[type eq "work"].value (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
r"^emails\[.+\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
@@ -80,17 +34,11 @@ class ScimPatchError(Exception):
super().__init__(detail)
@dataclass
class _UserPatchCtx:
"""Bundles the mutable state for user PATCH operations."""
data: dict[str, Any]
name_data: dict[str, Any]
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
@@ -112,126 +60,72 @@ def apply_user_patch(
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
name_data = data.get("name") or {}
for op in operations:
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data, ignored_paths)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data)
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set.
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a remove operation to user data — clears the target field."""
path = (op.path or "").lower()
if not path:
raise ScimPatchError("Remove operation requires a path")
if path in ignored_paths:
return
entry = _USER_REMOVE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = None
return
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
_set_user_field(path, op.value, data, name_data, ignored_paths)
def _set_user_field(
path: str,
value: ScimPatchValue,
ctx: _UserPatchCtx,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
*,
strict: bool = True,
) -> None:
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
"""Set a single field on user data by SCIM path."""
if path in ignored_paths:
return
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif not strict:
return
elif path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
data["displayName"] = value
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
"""Update the primary email entry via an email filter path."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
@@ -341,14 +235,12 @@ def _set_group_field(
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
elif path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(

View File

@@ -16,14 +16,6 @@ from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
@@ -23,4 +22,4 @@ class OktaProvider(ScimProvider):
@property
def ignored_patch_paths(self) -> frozenset[str]:
return COMMON_IGNORED_PATCH_PATHS
return frozenset({"id", "schemas", "meta"})

View File

@@ -123,21 +123,9 @@ def _seed_llms(
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
if len(seeded_providers[0].model_configurations) > 0:
default_model = next(
(
mc
for mc in seeded_providers[0].model_configurations
if mc.is_visible
),
seeded_providers[0].model_configurations[0],
).name
update_default_provider(
provider_id=seeded_providers[0].id,
model_name=default_model,
db_session=db_session,
)
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:

View File

@@ -302,12 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
def _upsert(request: LLMProviderUpsertRequest) -> None:
nonlocal has_set_default_provider
try:
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, default_model, db_session)
update_default_provider(provider.id, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider, default_model_name)
_upsert(openai_provider)
# Create default image generation config using the OpenAI API key
try:
@@ -360,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider, default_model_name)
_upsert(anthropic_provider)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -391,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider, default_model_name)
_upsert(vertexai_provider)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -429,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider, default_model_name)
_upsert(openrouter_provider)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -22,7 +22,6 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -59,7 +58,6 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
@@ -278,7 +276,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -328,7 +325,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -5,14 +5,11 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
@@ -27,14 +24,12 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -80,58 +75,10 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -685,8 +632,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -698,25 +645,13 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
sa.and_(
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
)
)
@@ -726,23 +661,19 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
):
skipped_guard += 1
continue
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -761,8 +692,6 @@ def process_single_user_file_project_sync(
)
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
@@ -776,11 +705,7 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
@@ -808,17 +733,13 @@ def process_single_user_file_project_sync(
]
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,
),
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
@@ -826,7 +747,6 @@ def process_single_user_file_project_sync(
)
user_file.needs_project_sync = False
user_file.needs_persona_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)

View File

@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -158,7 +156,36 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_batch.append(sanitize_document_for_postgres(doc))
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -575,13 +602,10 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch_cleaned,
nodes=hierarchy_node_batch,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -600,7 +624,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -30,7 +30,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
@@ -657,12 +656,7 @@ def run_llm_loop(
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
# Fetch this in a short-lived session so the long-running stream loop does
# not pin a connection just to keep read state alive.
with get_session_with_current_tenant() as prompt_db_session:
default_base_system_prompt: str = get_default_base_system_prompt(
prompt_db_session
)
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
system_prompt = None
custom_agent_prompt_msg = None

View File

@@ -856,11 +856,6 @@ def handle_stream_message_objects(
reserved_tokens=reserved_token_count,
)
# Release any read transaction before entering the long-running LLM stream.
# Without this, the request-scoped session can keep a connection checked out
# for the full stream duration.
db_session.commit()
# The stream generator can resume on a different worker thread after early yields.
# Set this right before launching the LLM loop so run_in_background copies the right context.
if new_msg_req.mock_llm_response is not None:

View File

@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER") or ""
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""

View File

@@ -167,14 +167,6 @@ CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
# How long a queued user-file-project-sync task remains valid.
# Should be short enough to discard stale queue entries under load while still
# allowing workers enough time to pick up new tasks.
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before user-file-project-sync producers stop enqueuing.
# This applies backpressure when workers are falling behind.
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -467,7 +459,6 @@ class OnyxRedisLocks:
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"

View File

@@ -16,22 +16,6 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
def _is_rate_limit_error(error: HttpError) -> bool:
"""Google sometimes returns rate-limit errors as 403 with reason
'userRateLimitExceeded' instead of 429. This helper detects both."""
if error.resp.status == 429:
return True
if error.resp.status != 403:
return False
error_details = getattr(error, "error_details", None) or []
for detail in error_details:
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
return True
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
@@ -73,7 +57,7 @@ def _execute_with_retry(request: Any) -> Any:
except HttpError as error:
attempt += 1
if _is_rate_limit_error(error):
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
@@ -156,16 +140,16 @@ def _execute_single_retrieval(
)
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e

View File

@@ -1,96 +0,0 @@
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
The office365 library's GraphClient requires an ``AzureEnvironment`` string
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
cloud. Our connectors instead expose free-text ``authority_host`` and
``graph_api_host`` fields so the frontend doesn't need to know about SDK
internals.
This module bridges the gap: given the two host URLs the user configured, it
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
"""
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
from pydantic import BaseModel
from onyx.connectors.exceptions import ConnectorValidationError
class MicrosoftGraphEnvironment(BaseModel):
"""One row of the inverse mapping."""
environment: str
graph_host: str
authority_host: str
sharepoint_domain_suffix: str
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Global,
graph_host="https://graph.microsoft.com",
authority_host="https://login.microsoftonline.com",
sharepoint_domain_suffix="sharepoint.com",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentHigh,
graph_host="https://graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentDoD,
graph_host="https://dod-graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.China,
graph_host="https://microsoftgraph.chinacloudapi.cn",
authority_host="https://login.chinacloudapi.cn",
sharepoint_domain_suffix="sharepoint.cn",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Germany,
graph_host="https://graph.microsoft.de",
authority_host="https://login.microsoftonline.de",
sharepoint_domain_suffix="sharepoint.de",
),
]
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
env.graph_host: env for env in _ENVIRONMENTS
}
def resolve_microsoft_environment(
graph_api_host: str,
authority_host: str,
) -> MicrosoftGraphEnvironment:
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
Raises ``ConnectorValidationError`` when the combination is unknown or
internally inconsistent (e.g. a GCC-High graph host paired with a
commercial authority host).
"""
graph_api_host = graph_api_host.rstrip("/")
authority_host = authority_host.rstrip("/")
env = _GRAPH_HOST_INDEX.get(graph_api_host)
if env is None:
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
raise ConnectorValidationError(
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
f"Recognised hosts: {known}"
)
if env.authority_host != authority_host:
raise ConnectorValidationError(
f"Authority host '{authority_host}' is inconsistent with "
f"graph API host '{graph_api_host}'. "
f"Expected authority host '{env.authority_host}' "
f"for the {env.environment} environment."
)
return env

View File

@@ -6,7 +6,6 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
return {
key: [str(item) for item in val] if isinstance(val, list) else str(val)
for key, val in v.items()
}
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None

View File

@@ -47,7 +47,6 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -147,9 +146,7 @@ class DriveItemData(BaseModel):
self.id,
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
)
item = DriveItem(graph_client, path)
item.set_property("id", self.id)
return item
return DriveItem(graph_client, path)
# The office365 library's ClientContext caches the access token from its
@@ -840,20 +837,10 @@ class SharepointConnector(
self._cached_rest_ctx: ClientContext | None = None
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
self.graph_api_base = f"{self.graph_api_host}/v1.0"
self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
logger.warning(
f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
f"for the {resolved_env.environment} environment. "
f"Using '{resolved_env.sharepoint_domain_suffix}'."
)
self.sharepoint_domain_suffix = sharepoint_domain_suffix
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
@@ -1605,7 +1592,6 @@ class SharepointConnector(
if certificate_data is None:
raise RuntimeError("Failed to load certificate")
logger.info(f"Creating MSAL app with authority url {authority_url}")
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
@@ -1637,9 +1623,7 @@ class SharepointConnector(
raise ConnectorValidationError("Failed to acquire token for graph")
return token
self._graph_client = GraphClient(
_acquire_token_for_graph, environment=self._azure_environment
)
self._graph_client = GraphClient(_acquire_token_for_graph)
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
org = self.graph_client.organization.get().execute_query()
if not org or len(org) == 0:

View File

@@ -11,7 +11,6 @@ from dateutil import parser
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -259,21 +258,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch
def validate_connector_settings(self) -> None:
"""
Very basic validation, we could do more here
"""
if not self.base_url.startswith("https://") and not self.base_url.startswith(
"http://"
):
raise ConnectorValidationError(
"Base URL must start with https:// or http://"
)
try:
get_all_post_ids(self.slab_bot_token)
except ConnectorMissingCredentialError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")

View File

@@ -23,7 +23,6 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -74,11 +73,8 @@ class TeamsConnector(
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -110,9 +106,7 @@ class TeamsConnector(
return token
self.graph_client = GraphClient(
_acquire_token_func, environment=self._azure_environment
)
self.graph_client = GraphClient(_acquire_token_func)
return None
def validate_connector_settings(self) -> None:

View File

@@ -1,21 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
db_session: Session,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
return server
def update_code_interpreter_server_enabled(
db_session: Session,
enabled: bool,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
server.server_enabled = enabled
db_session.commit()
return server

View File

@@ -213,12 +213,8 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = (
fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if llm_provider_upsert_request.id
else None
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if not existing_llm_provider:
@@ -242,6 +238,11 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -250,10 +251,6 @@ def upsert_llm_provider(
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
models_to_exist = {
mc.name for mc in llm_provider_upsert_request.model_configurations
}
# Build a lookup of existing model configurations by name (single iteration)
existing_by_name = {
mc.name: mc for mc in existing_llm_provider.model_configurations
@@ -309,6 +306,15 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -482,22 +488,6 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -614,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
_update_default_model(
db_session,
provider_id,
model_name,
provider.default_model_name,
LLMModelFlowType.CHAT,
)
@@ -806,6 +805,12 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes
@@ -861,6 +866,7 @@ def insert_new_model_configuration__no_commit(
is_visible=is_visible,
max_input_tokens=max_input_tokens,
display_name=display_name,
supports_image_input=LLMModelFlowType.VISION in supported_flows,
)
.on_conflict_do_nothing()
.returning(ModelConfiguration.id)
@@ -895,6 +901,7 @@ def update_model_configuration__no_commit(
is_visible=is_visible,
max_input_tokens=max_input_tokens,
display_name=display_name,
supports_image_input=LLMModelFlowType.VISION in supported_flows,
)
.where(ModelConfiguration.id == model_configuration_id)
.returning(ModelConfiguration)

View File

@@ -2822,9 +2822,14 @@ class LLMProvider(Base):
custom_config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Auto mode: models, visibility, and defaults are managed by GitHub config
@@ -2874,6 +2879,8 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Human-readable display name for the model.
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
@@ -4263,9 +4270,6 @@ class UserFile(Base):
needs_project_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
needs_persona_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
@@ -4936,11 +4940,6 @@ class ScimUserMapping(Base):
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
department: Mapped[str | None] = mapped_column(String, nullable=True)
manager: Mapped[str | None] = mapped_column(String, nullable=True)
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False

View File

@@ -765,9 +765,6 @@ def mark_persona_as_deleted(
) -> None:
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
persona.deleted = True
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
@@ -779,13 +776,11 @@ def mark_persona_as_not_deleted(
persona = get_persona_by_id(
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
)
if not persona.deleted:
if persona.deleted:
persona.deleted = False
db_session.commit()
else:
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
persona.deleted = False
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
def mark_delete_persona_by_name(
@@ -851,20 +846,6 @@ def update_personas_display_priority(
db_session.commit()
def _mark_files_need_persona_sync(
db_session: Session,
user_file_ids: list[UUID],
) -> None:
"""Flag the given UserFile rows so the background sync task picks them up
and updates their persona metadata in the vector DB."""
if not user_file_ids:
return
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
{UserFile.needs_persona_sync: True},
synchronize_session=False,
)
def upsert_persona(
user: User | None,
name: str,
@@ -1053,13 +1034,8 @@ def upsert_persona(
existing_persona.tools = tools or []
if user_file_ids is not None:
old_file_ids = {uf.id for uf in existing_persona.user_files}
new_file_ids = {uf.id for uf in (user_files or [])}
affected_file_ids = old_file_ids | new_file_ids
existing_persona.user_files.clear()
existing_persona.user_files = user_files or []
if affected_file_ids:
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
if hierarchy_node_ids is not None:
existing_persona.hierarchy_nodes.clear()
@@ -1113,8 +1089,6 @@ def upsert_persona(
attached_documents=attached_documents or [],
)
db_session.add(new_persona)
if user_files:
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
persona = new_persona
if commit:
db_session.commit()

View File

@@ -2,7 +2,6 @@ import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from uuid import UUID
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -14,26 +13,18 @@ from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(
num_sessions: int,
num_messages: int,
days: int,
user_id: UUID | None = None,
persona_id: int | None = None,
) -> None:
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
user_id: optional user to associate with sessions
persona_id: optional persona/assistant to associate with sessions
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")

View File

@@ -3,7 +3,6 @@ from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import UserFile
@@ -65,23 +64,6 @@ def fetch_user_project_ids_for_user_files(
}
def fetch_persona_ids_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch persona (assistant) ids for specified user files."""
stmt = (
select(UserFile)
.where(UserFile.id.in_(user_file_ids))
.options(selectinload(UserFile.assistants))
)
results = db_session.execute(stmt).scalars().all()
return {
str(user_file.id): [persona.id for persona in user_file.assistants]
for user_file in results
}
def update_last_accessed_at_for_user_files(
user_file_ids: list[UUID],
db_session: Session,

View File

@@ -121,7 +121,6 @@ class VespaDocumentUserFields:
"""
user_projects: list[int] | None = None
personas: list[int] | None = None
@dataclass

View File

@@ -148,7 +148,6 @@ class MetadataUpdateRequest(BaseModel):
hidden: bool | None = None
secondary_index_updated: bool | None = None
project_ids: set[int] | None = None
persona_ids: set[int] | None = None
class IndexRetrievalFilters(BaseModel):

View File

@@ -50,7 +50,6 @@ from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
@@ -216,7 +215,6 @@ def _convert_onyx_chunk_to_opensearch_document(
# OpenSearch and it will not store any data at all for this field, which
# is different from supplying an empty list.
user_projects=chunk.user_project or None,
personas=chunk.personas or None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
@@ -364,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
if user_fields and user_fields.user_projects
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
else None
),
)
try:
@@ -716,10 +709,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
update_request.project_ids
)
if update_request.persona_ids is not None:
properties_to_update[PERSONAS_FIELD_NAME] = list(
update_request.persona_ids
)
if not properties_to_update:
if len(update_request.document_ids) > 1:

View File

@@ -41,7 +41,6 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
USER_PROJECTS_FIELD_NAME = "user_projects"
PERSONAS_FIELD_NAME = "personas"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -157,7 +156,6 @@ class DocumentChunk(BaseModel):
document_sets: list[str] | None = None
user_projects: list[int] | None = None
personas: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@@ -487,7 +485,6 @@ class DocumentSchema:
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
PERSONAS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.

View File

@@ -181,11 +181,6 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field personas type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -689,9 +689,6 @@ class VespaIndex(DocumentIndex):
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(
document_ids=[doc_id],
doc_id_to_chunk_cnt={
@@ -702,7 +699,6 @@ class VespaIndex(DocumentIndex):
boost=fields.boost if fields is not None else None,
hidden=fields.hidden if fields is not None else None,
project_ids=project_ids,
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])

View File

@@ -46,7 +46,6 @@ from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -219,7 +218,6 @@ def _index_vespa_chunk(
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
PERSONAS: chunk.personas if chunk.personas is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -183,10 +183,6 @@ def _update_single_chunk(
model_config = {"frozen": True}
assign: list[int]
class _Personas(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based the Vespa schema. Changes to the
@@ -197,7 +193,6 @@ def _update_single_chunk(
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
personas: _Personas | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
@@ -232,11 +227,6 @@ def _update_single_chunk(
if update_request.project_ids is not None
else None
)
personas_update: _Personas | None = (
_Personas(assign=list(update_request.persona_ids))
if update_request.persona_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
@@ -244,7 +234,6 @@ def _update_single_chunk(
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
personas=personas_update,
)
vespa_put_request = _VespaPutRequest(

View File

@@ -58,7 +58,6 @@ DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
PERSONAS = "personas"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -12,9 +12,6 @@ if TYPE_CHECKING:
class AzureImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -56,25 +53,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
deployment_name=credentials.deployment_name,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Azure GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -82,44 +60,14 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation
deployment = self._deployment_name or model
model_name = f"azure/{deployment}"
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model_name,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(
prompt=prompt,
model=model_name,

View File

@@ -12,9 +12,6 @@ if TYPE_CHECKING:
class OpenAIImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -42,25 +39,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
api_base=credentials.api_base,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -68,38 +46,9 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model,
api_key=self._api_key,
api_base=self._api_base,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(

View File

@@ -146,7 +146,6 @@ class DocumentIndexingBatchAdapter:
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map

View File

@@ -20,7 +20,6 @@ from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.notification import create_notification
from onyx.db.user_file import fetch_chunk_counts_for_user_files
from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
@@ -120,10 +119,6 @@ class UserFileIndexingAdapter:
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -187,7 +182,7 @@ class UserFileIndexingAdapter:
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
# we are going to index userfiles only once, so we just set the boost to the default
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],

View File

@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -229,8 +228,6 @@ def index_doc_batch_prepare(
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = sanitize_documents_for_postgres(documents)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]

View File

@@ -112,7 +112,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess"
document_sets: set[str]
user_project: list[int]
personas: list[int]
boost: int
aggregated_chunk_boost_factor: float
# Full ancestor path from root hierarchy node to document's parent.
@@ -127,7 +126,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
user_project: list[int],
personas: list[int],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
@@ -139,7 +137,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,

View File

@@ -1,150 +0,0 @@
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
def _sanitize_string(value: str) -> str:
return value.replace("\x00", "")
def _sanitize_json_like(value: Any) -> Any:
if isinstance(value, str):
return _sanitize_string(value)
if isinstance(value, list):
return [_sanitize_json_like(item) for item in value]
if isinstance(value, tuple):
return tuple(_sanitize_json_like(item) for item in value)
if isinstance(value, dict):
sanitized: dict[Any, Any] = {}
for key, nested_value in value.items():
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
return sanitized
return value
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
return expert.model_copy(
update={
"display_name": (
_sanitize_string(expert.display_name)
if expert.display_name is not None
else None
),
"first_name": (
_sanitize_string(expert.first_name)
if expert.first_name is not None
else None
),
"middle_initial": (
_sanitize_string(expert.middle_initial)
if expert.middle_initial is not None
else None
),
"last_name": (
_sanitize_string(expert.last_name)
if expert.last_name is not None
else None
),
"email": (
_sanitize_string(expert.email) if expert.email is not None else None
),
}
)
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
return ExternalAccess(
external_user_emails={
_sanitize_string(email) for email in external_access.external_user_emails
},
external_user_group_ids={
_sanitize_string(group_id)
for group_id in external_access.external_user_group_ids
},
is_public=external_access.is_public,
)
def sanitize_document_for_postgres(document: Document) -> Document:
cleaned_doc = document.model_copy(deep=True)
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
if cleaned_doc.title is not None:
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
cleaned_doc.parent_hierarchy_raw_node_id
)
cleaned_doc.metadata = {
_sanitize_string(key): (
[_sanitize_string(item) for item in value]
if isinstance(value, list)
else _sanitize_string(value)
)
for key, value in cleaned_doc.metadata.items()
}
if cleaned_doc.doc_metadata is not None:
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
if cleaned_doc.primary_owners is not None:
cleaned_doc.primary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
]
if cleaned_doc.secondary_owners is not None:
cleaned_doc.secondary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
]
if cleaned_doc.external_access is not None:
cleaned_doc.external_access = _sanitize_external_access(
cleaned_doc.external_access
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = _sanitize_string(section.link)
if section.text is not None:
section.text = _sanitize_string(section.text)
if section.image_file_id is not None:
section.image_file_id = _sanitize_string(section.image_file_id)
return cleaned_doc
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
return [sanitize_document_for_postgres(document) for document in documents]
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
cleaned_node = node.model_copy(deep=True)
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
if cleaned_node.raw_parent_id is not None:
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
if cleaned_node.link is not None:
cleaned_node.link = _sanitize_string(cleaned_node.link)
if cleaned_node.external_access is not None:
cleaned_node.external_access = _sanitize_external_access(
cleaned_node.external_access
)
return cleaned_node
def sanitize_hierarchy_nodes_for_postgres(
nodes: list[HierarchyNode],
) -> list[HierarchyNode]:
return [sanitize_hierarchy_node_for_postgres(node) for node in nodes]

View File

@@ -97,9 +97,6 @@ from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.kg.api import admin_router as kg_admin_router
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.code_interpreter.api import (
admin_router as code_interpreter_admin_router,
)
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
@@ -424,9 +421,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, kg_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, code_interpreter_admin_router
)
include_router_with_global_prefix_prepended(
application, image_generation_admin_router
)

View File

@@ -1,68 +1,14 @@
import re
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_citation_link_destinations(message: str) -> str:
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
if "[[" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _CITATION_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
normalized_message = _normalize_citation_link_destinations(message)
result = md(normalized_message)
result = md(message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result

View File

@@ -762,43 +762,6 @@ def download_webapp(
)
@router.get("/{session_id}/download-directory/{path:path}")
def download_directory(
session_id: UUID,
path: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Response:
"""
Download a directory as a zip file.
Returns the specified directory as a zip archive.
"""
user_id: UUID = user.id
session_manager = SessionManager(db_session)
try:
result = session_manager.download_directory(session_id, user_id, path)
except ValueError as e:
error_message = str(e)
if "path traversal" in error_message.lower():
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=400, detail=error_message)
if result is None:
raise HTTPException(status_code=404, detail="Directory not found")
zip_bytes, filename = result
return Response(
content=zip_bytes,
media_type="application/zip",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
},
)
@router.post("/{session_id}/upload", response_model=UploadResponse)
def upload_file_endpoint(
session_id: UUID,

View File

@@ -107,23 +107,27 @@ def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int,
)
for cc_pair in cc_pairs:
if (
cc_pair.connector.source == DocumentSource.CRAFT_FILE
and cc_pair.creator_id == user.id
):
if cc_pair.connector.source == DocumentSource.CRAFT_FILE:
return cc_pair.connector.id, cc_pair.credential.id
# No cc_pair for this user — find or create the shared CRAFT_FILE connector
# Check for orphaned connector (created but cc_pair creation failed previously)
existing_connectors = fetch_connectors(
db_session, sources=[DocumentSource.CRAFT_FILE]
)
connector_id: int | None = None
orphaned_connector = None
for conn in existing_connectors:
if conn.name == USER_LIBRARY_CONNECTOR_NAME:
connector_id = conn.id
if conn.name != USER_LIBRARY_CONNECTOR_NAME:
continue
if not conn.credentials:
orphaned_connector = conn
break
if connector_id is None:
if orphaned_connector:
connector_id = orphaned_connector.id
logger.info(
f"Found orphaned User Library connector {connector_id}, completing setup"
)
else:
connector_data = ConnectorBase(
name=USER_LIBRARY_CONNECTOR_NAME,
source=DocumentSource.CRAFT_FILE,

View File

@@ -1,19 +1,15 @@
#!/usr/bin/env python3
"""Generate AGENTS.md by scanning the files directory and populating the template.
This script runs during session setup, AFTER files have been synced from S3
and the files symlink has been created. It reads an existing AGENTS.md (which
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
placeholder by scanning the knowledge source directory, and writes it back.
This script runs at container startup, AFTER the init container has synced files
from S3. It scans the /workspace/files directory to discover what knowledge sources
are available and generates appropriate documentation.
Usage:
python3 generate_agents_md.py <agents_md_path> <files_path>
Arguments:
agents_md_path: Path to the AGENTS.md file to update in place
files_path: Path to the files directory to scan for knowledge sources
Environment variables:
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
"""
import os
import sys
from pathlib import Path
@@ -193,39 +189,49 @@ def build_knowledge_sources_section(files_path: Path) -> str:
def main() -> None:
"""Main entry point for container startup script.
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
placeholder by scanning the files directory, and writes it back.
Usage:
python3 generate_agents_md.py <agents_md_path> <files_path>
Is called by the container startup script to scan /workspace/files and populate
the knowledge sources section.
"""
if len(sys.argv) != 3:
print(
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
file=sys.stderr,
)
sys.exit(1)
# Read template from environment variable
template = os.environ.get("AGENT_INSTRUCTIONS", "")
if not template:
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
template = "# Agent Instructions\n\nNo instructions provided."
agents_md_path = Path(sys.argv[1])
files_path = Path(sys.argv[2])
# Scan files directory - check /workspace/files first, then /workspace/demo_data
files_path = Path("/workspace/files")
demo_data_path = Path("/workspace/demo_data")
if not agents_md_path.exists():
print(f"Error: {agents_md_path} not found", file=sys.stderr)
sys.exit(1)
# Use demo_data if files doesn't exist or is empty
if not files_path.exists() or not any(files_path.iterdir()):
if demo_data_path.exists():
files_path = demo_data_path
template = agents_md_path.read_text()
knowledge_sources_section = build_knowledge_sources_section(files_path)
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
resolved_files_path = files_path.resolve()
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
# Replace placeholder and write back
content = template.replace(
# Replace placeholders
content = template
content = content.replace(
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
)
agents_md_path.write_text(content)
print(f"Populated knowledge sources in {agents_md_path}")
# Write AGENTS.md
output_path = Path("/workspace/AGENTS.md")
output_path.write_text(content)
# Log result
source_count = 0
if files_path.exists():
source_count = len(
[
d
for d in files_path.iterdir()
if d.is_dir() and not d.name.startswith(".")
]
)
print(
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
)
if __name__ == "__main__":

View File

@@ -1352,9 +1352,6 @@ fi
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
# Populate knowledge sources by scanning the files directory
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
# Write opencode config
echo "Writing opencode.json"
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
@@ -1783,9 +1780,6 @@ ln -sf {symlink_target} {session_path}/files
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
# Populate knowledge sources by scanning the files directory
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
# Write opencode config
echo "Writing opencode.json"
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json

View File

@@ -68,7 +68,6 @@ from onyx.server.features.build.db.sandbox import create_sandbox__no_commit
from onyx.server.features.build.db.sandbox import get_running_sandbox_count_by_tenant
from onyx.server.features.build.db.sandbox import get_sandbox_by_session_id
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import get_snapshots_for_session
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
from onyx.server.features.build.sandbox import get_sandbox_manager
@@ -647,30 +646,16 @@ class SessionManager:
if sandbox and sandbox.status.is_active():
# Quick health check to verify sandbox is actually responsive
# AND verify the session workspace still exists on disk
# (it may have been wiped if the sandbox was re-provisioned)
is_healthy = self._sandbox_manager.health_check(sandbox.id, timeout=5.0)
workspace_exists = (
is_healthy
and self._sandbox_manager.session_workspace_exists(
sandbox.id, existing.id
)
)
if is_healthy and workspace_exists:
if self._sandbox_manager.health_check(sandbox.id, timeout=5.0):
logger.info(
f"Returning existing empty session {existing.id} for user {user_id}"
)
return existing
elif not is_healthy:
else:
logger.warning(
f"Empty session {existing.id} has unhealthy sandbox {sandbox.id}. "
f"Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} workspace missing in sandbox "
f"{sandbox.id}. Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} has no active sandbox "
@@ -1050,23 +1035,6 @@ class SessionManager:
# workspace cleanup fails (e.g., if pod is already terminated)
logger.warning(f"Failed to cleanup session workspace {session_id}: {e}")
# Delete snapshot files from S3 before removing DB records
snapshots = get_snapshots_for_session(self._db_session, session_id)
if snapshots:
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.sandbox.manager.snapshot_manager import (
SnapshotManager,
)
snapshot_manager = SnapshotManager(get_default_file_store())
for snapshot in snapshots:
try:
snapshot_manager.delete_snapshot(snapshot.storage_path)
except Exception as e:
logger.warning(
f"Failed to delete snapshot file {snapshot.storage_path}: {e}"
)
# Delete session (uses flush, caller commits)
return delete_build_session__no_commit(session_id, user_id, self._db_session)
@@ -1935,94 +1903,6 @@ class SessionManager:
return zip_buffer.getvalue(), filename
def download_directory(
self,
session_id: UUID,
user_id: UUID,
path: str,
) -> tuple[bytes, str] | None:
"""
Create a zip file of an arbitrary directory in the session workspace.
Args:
session_id: The session UUID
user_id: The user ID to verify ownership
path: Relative path to the directory (within session workspace)
Returns:
Tuple of (zip_bytes, filename) or None if session not found
Raises:
ValueError: If path traversal attempted or path is not a directory
"""
# Verify session ownership
session = get_build_session(session_id, user_id, self._db_session)
if session is None:
return None
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
if sandbox is None:
return None
# Check if directory exists
try:
self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError:
return None
# Recursively collect all files
def collect_files(dir_path: str) -> list[tuple[str, str]]:
"""Collect all files recursively, returning (full_path, arcname) tuples."""
files: list[tuple[str, str]] = []
try:
entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=dir_path,
)
for entry in entries:
if entry.is_directory:
files.extend(collect_files(entry.path))
else:
# arcname is relative to the target directory
prefix_len = len(path) + 1 # +1 for trailing slash
arcname = entry.path[prefix_len:]
files.append((entry.path, arcname))
except ValueError:
pass
return files
file_list = collect_files(path)
# Create zip file in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
for full_path, arcname in file_list:
try:
content = self._sandbox_manager.read_file(
sandbox_id=sandbox.id,
session_id=session_id,
path=full_path,
)
zip_file.writestr(arcname, content)
except ValueError:
pass
zip_buffer.seek(0)
# Use the directory name for the zip filename
dir_name = Path(path).name
safe_name = "".join(
c if c.isalnum() or c in ("-", "_", ".") else "_" for c in dir_name
)
filename = f"{safe_name}.zip"
return zip_buffer.getvalue(), filename
# =========================================================================
# File System Operations
# =========================================================================
@@ -2057,18 +1937,11 @@ class SessionManager:
return None
# Use sandbox manager to list directory (works for both local and K8s)
# If the directory doesn't exist (e.g., session workspace not yet loaded),
# return an empty listing rather than erroring out.
try:
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError as e:
if "path traversal" in str(e).lower():
raise
return DirectoryListing(path=path, entries=[])
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
# Filter hidden files and directories
entries: list[FileSystemEntry] = [

View File

@@ -12,18 +12,11 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import UserFileStatus
from onyx.db.models import ChatSession
@@ -34,7 +27,6 @@ from onyx.db.models import UserProject
from onyx.db.persona import get_personas_by_ids
from onyx.db.projects import get_project_token_count
from onyx.db.projects import upload_files_to_user_files_with_indexing
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import ChatSessionRequest
from onyx.server.features.projects.models import TokenCountResponse
@@ -55,33 +47,6 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
queue_depth = get_user_file_project_sync_queue_depth(client_app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
logger.warning(
f"Skipping immediate project sync for user_file_id={user_file_id} due to "
f"queue depth {queue_depth}>{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}. "
"It will be picked up by beat later."
)
return
redis_client = get_redis_client(tenant_id=tenant_id)
enqueued = enqueue_user_file_project_sync_task(
celery_app=client_app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGHEST,
)
if not enqueued:
logger.info(
f"Skipped duplicate project sync enqueue for user_file_id={user_file_id}"
)
return
logger.info(f"Triggered project sync for user_file_id={user_file_id}")
@router.get("", tags=PUBLIC_API_TAGS)
def get_projects(
user: User = Depends(current_user),
@@ -224,7 +189,15 @@ def unlink_user_file_from_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
return Response(status_code=204)
@@ -268,7 +241,15 @@ def link_user_file_to_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
return UserFileSnapshot.from_model(user_file)

View File

@@ -1,47 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
admin_router = APIRouter(prefix="/admin/code-interpreter")
@admin_router.get("/health")
def get_code_interpreter_health(
_: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
try:
client = CodeInterpreterClient()
return CodeInterpreterServerHealth(healthy=client.health())
except ValueError:
return CodeInterpreterServerHealth(healthy=False)
@admin_router.get("")
def get_code_interpreter(
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
ci_server = fetch_code_interpreter_server(db_session)
return CodeInterpreterServer(enabled=ci_server.server_enabled)
@admin_router.put("")
def update_code_interpreter(
update: CodeInterpreterServer,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_code_interpreter_server_enabled(
db_session=db_session,
enabled=update.enabled,
)

View File

@@ -1,9 +0,0 @@
from pydantic import BaseModel
class CodeInterpreterServer(BaseModel):
enabled: bool
class CodeInterpreterServerHealth(BaseModel):
healthy: bool

View File

@@ -97,6 +97,7 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -237,12 +233,12 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.id:
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -272,7 +268,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.model,
model=test_llm_request.default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -307,7 +303,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderView]:
) -> list[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -332,25 +328,7 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
default_model = None
if model_config := fetch_default_llm_model(db_session):
default_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=default_model,
default_vision=default_vision_model,
)
return llm_provider_list
@admin_router.put("/provider")
@@ -363,29 +341,21 @@ def put_llm_provider(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderView:
# NOTE: Name updating functionality currently not supported. There are many places that still
# rely on immutable names, so this will be a larger change
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
existing_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
id={llm_provider_upsert_request.id} already exists",
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
id={llm_provider_upsert_request.id} does not exist",
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -423,6 +393,22 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -452,8 +438,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -483,29 +469,28 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/default")
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
default_model_request: DefaultModel,
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
update_default_provider(provider_id=provider_id, db_session=db_session)
@admin_router.post("/default-vision")
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
default_model_request: DefaultModel,
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=default_model_request.provider_id,
vision_model=default_model_request.model_name,
db_session=db_session,
provider_id=provider_id, vision_model=vision_model, db_session=db_session
)
@@ -531,7 +516,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[VisionProviderResponse]:
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -560,18 +545,7 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=default_vision_model,
)
return vision_provider_response
"""Endpoints for all"""
@@ -581,7 +555,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderDescriptor]:
) -> list[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -618,25 +592,7 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
default_model = None
if model_config := fetch_default_llm_model(db_session):
default_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=default_model,
default_vision=default_vision_model,
)
return accessible_providers
def get_valid_model_names_for_persona(
@@ -679,7 +635,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderDescriptor]:
) -> list[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -726,63 +682,7 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
# Get the default model and vision model for the persona
# NOTE: This should be ported over to use id as it is blocking on name mutability
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text: DefaultModel | None = (
DefaultModel(
provider_id=default_text_model.llm_provider.id,
model_name=default_text_model.name,
)
if default_text_model
else None
)
default_vision: DefaultModel | None = (
DefaultModel(
provider_id=default_vision_model.llm_provider.id,
model_name=default_vision_model.name,
)
if default_vision_model
else None
)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider:
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
return llm_provider_list
@admin_router.get("/provider-contextual-cost")

View File

@@ -1,7 +1,5 @@
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -23,8 +21,6 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
@@ -56,18 +52,19 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
class TestLLMRequest(BaseModel):
# provider level
id: int | None = None
name: str | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -83,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -99,12 +99,22 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = default_model_name or llm_provider_model.default_model_name
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -118,17 +128,18 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -144,6 +155,8 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -165,6 +178,14 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = default_model_name or llm_provider_model.default_model_name
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -177,6 +198,10 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -203,8 +228,7 @@ class ModelConfigurationUpsertRequest(BaseModel):
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=LLMModelFlowType.VISION
in model_configuration_model.llm_model_flow_types,
supports_image_input=model_configuration_model.supports_image_input,
display_name=model_configuration_model.display_name,
)
@@ -397,27 +421,3 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> "LLMProviderResponse[T]":
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -35,18 +35,6 @@ if TYPE_CHECKING:
pass
class EmailInviteStatus(str, Enum):
SENT = "SENT"
NOT_CONFIGURED = "NOT_CONFIGURED"
SEND_FAILED = "SEND_FAILED"
DISABLED = "DISABLED"
class BulkInviteResponse(BaseModel):
invited_count: int
email_invite_status: EmailInviteStatus
class VersionResponse(BaseModel):
backend_version: str

View File

@@ -36,7 +36,6 @@ from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
@@ -79,10 +78,8 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import BulkInviteResponse
from onyx.server.manage.models import ChatBackgroundRequest
from onyx.server.manage.models import DefaultAppModeRequest
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import PersonalizationUpdateRequest
from onyx.server.manage.models import TenantInfo
@@ -371,7 +368,7 @@ def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> BulkInviteResponse:
) -> int:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
tenant_id = get_current_tenant_id()
@@ -430,41 +427,34 @@ def bulk_invite_users(
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations only to new users (not already invited or existing)
if not ENABLE_EMAIL_INVITES:
email_invite_status = EmailInviteStatus.DISABLED
elif not EMAIL_CONFIGURED:
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
else:
if ENABLE_EMAIL_INVITES:
try:
for email in emails_needing_seats:
send_user_email_invite(email, current_user, AUTH_TYPE)
email_invite_status = EmailInviteStatus.SENT
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
email_invite_status = EmailInviteStatus.SEND_FAILED
if MULTI_TENANT and not DEV_MODE:
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
return BulkInviteResponse(
invited_count=number_of_invited_users,
email_invite_status=email_invite_status,
)
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)

View File

@@ -587,7 +587,6 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
)
result = gather_stream_full(packets, state_container)
@@ -610,7 +609,6 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
):
yield get_json_line(obj.model_dump())

View File

@@ -125,11 +125,6 @@ class SendMessageRequest(BaseModel):
# - No CitationInfo packets are emitted during streaming
include_citations: bool = True
# Additional context injected into the LLM call but NOT stored in the DB
# (not shown in chat history). Used e.g. by the Chrome extension to pass
# the current tab URL when "Read this tab" is enabled.
additional_context: str | None = None
@model_validator(mode="after")
def check_chat_session_id_or_info(self) -> "SendMessageRequest":
# If neither is provided, default to creating a new chat session using the

View File

@@ -245,11 +245,7 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
if (
GEN_AI_API_KEY
and fetch_default_llm_model(db_session) is None
and not INTEGRATION_TESTS_MODE
):
if GEN_AI_API_KEY and fetch_default_llm_model(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
@@ -261,6 +257,7 @@ def setup_postgres(db_session: Session) -> None:
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -272,9 +269,7 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -1,8 +1,5 @@
import json
from collections.abc import Generator
from typing import Literal
from typing import TypedDict
from typing import Union
import requests
from pydantic import BaseModel
@@ -39,39 +36,6 @@ class ExecuteResponse(BaseModel):
files: list[WorkspaceFile]
class StreamOutputEvent(BaseModel):
"""SSE 'output' event: a chunk of stdout or stderr"""
stream: Literal["stdout", "stderr"]
data: str
class StreamResultEvent(BaseModel):
"""SSE 'result' event: final execution result"""
exit_code: int | None
timed_out: bool
duration_ms: int
files: list[WorkspaceFile]
class StreamErrorEvent(BaseModel):
"""SSE 'error' event: execution-level error"""
message: str
StreamEvent = Union[StreamOutputEvent, StreamResultEvent, StreamErrorEvent]
_SSE_EVENT_MAP: dict[
str, type[StreamOutputEvent | StreamResultEvent | StreamErrorEvent]
] = {
"output": StreamOutputEvent,
"result": StreamResultEvent,
"error": StreamErrorEvent,
}
class CodeInterpreterClient:
"""Client for Code Interpreter service"""
@@ -81,34 +45,6 @@ class CodeInterpreterClient:
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
def _build_payload(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> dict:
payload: dict = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
return payload
def health(self) -> bool:
"""Check if the Code Interpreter service is healthy"""
url = f"{self.base_url}/health"
try:
response = self.session.get(url, timeout=5)
response.raise_for_status()
return response.json().get("status") == "ok"
except Exception as e:
logger.warning(f"Exception caught when checking health, e={e}")
return False
def execute(
self,
code: str,
@@ -116,110 +52,25 @@ class CodeInterpreterClient:
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> ExecuteResponse:
"""Execute Python code (batch)"""
"""Execute Python code"""
url = f"{self.base_url}/v1/execute"
payload = self._build_payload(code, stdin, timeout_ms, files)
payload = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
response.raise_for_status()
return ExecuteResponse(**response.json())
def execute_streaming(
self,
code: str,
stdin: str | None = None,
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> Generator[StreamEvent, None, None]:
"""Execute Python code with streaming SSE output.
Yields StreamEvent objects (StreamOutputEvent, StreamResultEvent,
StreamErrorEvent) as execution progresses. Falls back to batch
execution if the streaming endpoint is not available (older
code-interpreter versions).
"""
url = f"{self.base_url}/v1/execute/stream"
payload = self._build_payload(code, stdin, timeout_ms, files)
response = self.session.post(
url,
json=payload,
stream=True,
timeout=timeout_ms / 1000 + 10,
)
if response.status_code == 404:
logger.info(
"Streaming endpoint not available, " "falling back to batch execution"
)
response.close()
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response
) -> Generator[StreamEvent, None, None]:
"""Parse SSE streaming response into StreamEvent objects.
Expected format per event:
event: <type>
data: <json>
<blank line>
"""
event_type: str | None = None
data_lines: list[str] = []
for line in response.iter_lines(decode_unicode=True):
if line is None:
continue
if line == "":
# Blank line marks end of an SSE event
if event_type is not None and data_lines:
data = "\n".join(data_lines)
model_cls = _SSE_EVENT_MAP.get(event_type)
if model_cls is not None:
yield model_cls(**json.loads(data))
else:
logger.warning(f"Unknown SSE event type: {event_type}")
event_type = None
data_lines = []
elif line.startswith("event:"):
event_type = line[len("event:") :].strip()
elif line.startswith("data:"):
data_lines.append(line[len("data:") :].strip())
if event_type is not None or data_lines:
logger.warning(
f"SSE stream ended with incomplete event: "
f"event_type={event_type}, data_lines={data_lines}"
)
def _batch_as_stream(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> Generator[StreamEvent, None, None]:
"""Execute via batch endpoint and yield results as stream events."""
result = self.execute(code, stdin, timeout_ms, files)
if result.stdout:
yield StreamOutputEvent(stream="stdout", data=result.stdout)
if result.stderr:
yield StreamOutputEvent(stream="stderr", data=result.stderr)
yield StreamResultEvent(
exit_code=result.exit_code,
timed_out=result.timed_out,
duration_ms=result.duration_ms,
files=result.files,
)
def upload_file(self, file_content: bytes, filename: str) -> str:
"""Upload file to Code Interpreter and return file_id"""
url = f"{self.base_url}/v1/files"

View File

@@ -12,7 +12,6 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -29,15 +28,6 @@ from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamErrorEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamOutputEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
from onyx.utils.logger import setup_logger
@@ -104,10 +94,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
if not CODE_INTERPRETER_BASE_URL:
return False
server = fetch_code_interpreter_server(db_session)
return server.server_enabled
is_available = bool(CODE_INTERPRETER_BASE_URL)
return is_available
def tool_definition(self) -> dict:
return {
@@ -193,50 +181,19 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
try:
logger.debug(f"Executing code: {code}")
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
for event in client.execute_streaming(
# Execute code with timeout
response = client.execute(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Handle generated files
@@ -245,7 +202,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
for workspace_file in response.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
@@ -301,23 +258,26 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
)
# Emit delta with stdout/stderr and generated files
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=truncated_stdout,
stderr=truncated_stderr,
file_ids=generated_file_ids,
),
)
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
exit_code=response.exit_code,
timed_out=response.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
error=None if response.exit_code == 0 else truncated_stderr,
)
# Serialize result for LLM

View File

@@ -6,8 +6,6 @@ aioboto3==15.1.0
# via onyx
aiobotocore==2.24.0
# via aioboto3
aiofile==3.9.0
# via py-key-value-aio
aiofiles==25.1.0
# via
# aioboto3
@@ -42,10 +40,8 @@ anyio==4.11.0
# httpx
# mcp
# openai
# py-key-value-aio
# sse-starlette
# starlette
# watchfiles
argon2-cffi==23.1.0
# via pwdlib
argon2-cffi-bindings==25.1.0
@@ -78,7 +74,9 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12'
bcrypt==4.3.0
# via pwdlib
beartype==0.22.6
# via py-key-value-aio
# via
# py-key-value-aio
# py-key-value-shared
beautifulsoup4==4.12.3
# via
# atlassian-python-api
@@ -112,8 +110,6 @@ cachetools==6.2.2
# via
# google-auth
# py-key-value-aio
caio==0.9.25
# via aiofile
celery==5.5.1
# via onyx
certifi==2025.11.12
@@ -174,6 +170,7 @@ cloudpickle==3.1.2
# via
# dask
# distributed
# pydocket
cobble==0.1.4
# via mammoth
cohere==5.6.1
@@ -221,6 +218,8 @@ deprecated==1.3.1
# pygithub
discord-py==2.4.0
# via onyx
diskcache==5.6.3
# via py-key-value-aio
distributed==2026.1.1
# via onyx
distro==1.9.0
@@ -257,6 +256,8 @@ exceptiongroup==1.3.0
# via
# braintrust
# fastmcp
fakeredis==2.33.0
# via pydocket
fastapi==0.128.0
# via
# fastapi-limiter
@@ -272,7 +273,7 @@ fastapi-users-db-sqlalchemy==7.0.0
# via onyx
fastavro==1.12.1
# via cohere
fastmcp==3.0.2
fastmcp==2.14.2
# via onyx
fastuuid==0.14.0
# via litellm
@@ -477,9 +478,7 @@ jsonpatch==1.33
jsonpointer==3.0.0
# via jsonpatch
jsonref==1.1.0
# via
# fastmcp
# onyx
# via onyx
jsonschema==4.25.1
# via
# litellm
@@ -514,6 +513,8 @@ locket==1.0.0
# via
# distributed
# partd
lupa==2.6
# via fakeredis
lxml==5.3.0
# via
# htmldate
@@ -555,7 +556,7 @@ marshmallow==3.26.2
# via dataclasses-json
matrix-client==0.3.2
# via zulip
mcp==1.26.0
mcp==1.25.0
# via
# claude-agent-sdk
# fastmcp
@@ -612,7 +613,7 @@ oauthlib==3.2.2
# kubernetes
# onyx
# requests-oauthlib
office365-rest-python-client==2.6.2
office365-rest-python-client==2.5.9
# via onyx
olefile==0.47
# via
@@ -641,16 +642,22 @@ opensearch-py==3.0.0
opentelemetry-api==1.39.1
# via
# ddtrace
# fastmcp
# langfuse
# openinference-instrumentation
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-exporter-prometheus
# opentelemetry-instrumentation
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pydocket
opentelemetry-exporter-otlp-proto-common==1.39.1
# via opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-http==1.39.1
# via langfuse
opentelemetry-exporter-prometheus==0.60b1
# via pydocket
opentelemetry-instrumentation==0.60b1
# via pydocket
opentelemetry-proto==1.39.1
# via
# onyx
@@ -661,15 +668,17 @@ opentelemetry-sdk==1.39.1
# langfuse
# openinference-instrumentation
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-exporter-prometheus
opentelemetry-semantic-conventions==0.60b1
# via opentelemetry-sdk
# via
# opentelemetry-instrumentation
# opentelemetry-sdk
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
# via langsmith
packaging==24.2
# via
# dask
# distributed
# fastmcp
# google-cloud-aiplatform
# google-cloud-bigquery
# huggingface-hub
@@ -680,6 +689,7 @@ packaging==24.2
# langsmith
# marshmallow
# onnxruntime
# opentelemetry-instrumentation
# pytest
# pywikibot
pandas==2.3.3
@@ -692,6 +702,8 @@ passlib==1.7.4
# via onyx
pathable==0.4.4
# via jsonschema-path
pathvalidate==3.3.1
# via py-key-value-aio
pdfminer-six==20251107
# via markitdown
pillow==12.1.1
@@ -711,7 +723,9 @@ ply==3.11
prometheus-client==0.23.1
# via
# onyx
# opentelemetry-exporter-prometheus
# prometheus-fastapi-instrumentator
# pydocket
prometheus-fastapi-instrumentator==7.1.0
# via onyx
prompt-toolkit==3.0.52
@@ -750,8 +764,12 @@ pwdlib==0.3.0
# via fastapi-users
py==1.11.0
# via retry
py-key-value-aio==0.4.4
# via fastmcp
py-key-value-aio==0.3.0
# via
# fastmcp
# pydocket
py-key-value-shared==0.3.0
# via py-key-value-aio
pyairtable==3.0.1
# via onyx
pyasn1==0.6.2
@@ -788,6 +806,8 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pydocket==0.16.3
# via fastmcp
pyee==13.0.0
# via playwright
pygithub==2.5.0
@@ -859,6 +879,8 @@ python-http-client==3.3.7
# via sendgrid
python-iso639==2025.11.16
# via unstructured
python-json-logger==4.0.0
# via pydocket
python-magic==0.4.27
# via unstructured
python-multipart==0.0.22
@@ -896,7 +918,6 @@ pyyaml==6.0.3
# via
# dask
# distributed
# fastmcp
# huggingface-hub
# jsonschema-path
# kubernetes
@@ -907,8 +928,11 @@ rapidfuzz==3.13.0
# unstructured
redis==5.0.8
# via
# fakeredis
# fastapi-limiter
# onyx
# py-key-value-aio
# pydocket
referencing==0.36.2
# via
# jsonschema
@@ -983,6 +1007,7 @@ rich==14.2.0
# via
# cyclopts
# fastmcp
# pydocket
# rich-rst
# typer
rich-rst==1.3.2
@@ -1031,7 +1056,9 @@ sniffio==1.3.1
# anyio
# openai
sortedcontainers==2.4.0
# via distributed
# via
# distributed
# fakeredis
soupsieve==2.8
# via beautifulsoup4
sqlalchemy==2.0.15
@@ -1097,7 +1124,9 @@ tqdm==4.67.1
trafilatura==1.12.2
# via onyx
typer==0.20.0
# via mcp
# via
# mcp
# pydocket
types-awscrt==0.28.4
# via botocore-stubs
types-openpyxl==3.0.4.7
@@ -1133,10 +1162,11 @@ typing-extensions==4.15.0
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# py-key-value-aio
# py-key-value-shared
# pyairtable
# pydantic
# pydantic-core
# pydocket
# pyee
# pygithub
# python-docx
@@ -1204,8 +1234,6 @@ vine==5.1.0
# kombu
voyageai==0.2.3
# via onyx
watchfiles==1.1.1
# via fastmcp
wcwidth==0.2.14
# via prompt-toolkit
webencodings==0.5.1
@@ -1226,6 +1254,7 @@ wrapt==1.17.3
# deprecated
# langfuse
# openinference-instrumentation
# opentelemetry-instrumentation
# unstructured
xlrd==2.0.2
# via markitdown

View File

@@ -288,7 +288,7 @@ matplotlib-inline==0.2.1
# via
# ipykernel
# ipython
mcp==1.26.0
mcp==1.25.0
# via claude-agent-sdk
multidict==6.7.0
# via
@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.1
onyx-devtools==0.6.0
# via onyx
openai==2.14.0
# via

View File

@@ -211,7 +211,7 @@ litellm==1.81.6
# via onyx
markupsafe==3.0.3
# via jinja2
mcp==1.26.0
mcp==1.25.0
# via claude-agent-sdk
monotonic==1.6
# via posthog

View File

@@ -246,7 +246,7 @@ litellm==1.81.6
# via onyx
markupsafe==3.0.3
# via jinja2
mcp==1.26.0
mcp==1.25.0
# via claude-agent-sdk
mpmath==1.3.0
# via sympy

View File

@@ -95,7 +95,6 @@ def generate_dummy_chunk(
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
user_project=[],
personas=[],
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,

View File

@@ -3,8 +3,8 @@ set -e
cleanup() {
echo "Error occurred. Cleaning up..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -55,10 +55,6 @@ else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
# Start the Code Interpreter container
echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"

View File

@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from dotenv import load_dotenv
@@ -47,15 +46,11 @@ def mock_current_admin_user() -> MagicMock:
@pytest.fixture(scope="function")
def client() -> Generator[TestClient, None, None]:
# Initialize TestClient with the FastAPI app using a no-op test lifespan.
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
# CollectorRegistry" errors when multiple tests each create a new app
# (prometheus registers metrics globally and rejects duplicate names).
# Initialize TestClient with the FastAPI app using a no-op test lifespan
get_app = fetch_versioned_implementation(
module="onyx.main", attribute="get_application"
)
with patch("onyx.main.setup_prometheus_metrics"):
app: FastAPI = get_app(lifespan_override=test_lifespan)
app: FastAPI = get_app(lifespan_override=test_lifespan)
# Override the database session dependency with a mock
# (these tests don't actually need DB access)

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -26,6 +26,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
},
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
"api_key_changed": True,
"custom_config_changed": True,
}
@@ -43,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -52,6 +53,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
},
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
"api_key_changed": True,
"custom_config_changed": True,
}

View File

@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
update_default_provider(provider.id, db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()

View File

@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
update_default_provider(anthropic_provider.id, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -45,14 +44,15 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> LLMProviderView:
) -> None:
"""Helper to create a test LLM provider in the database."""
return upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -107,7 +107,12 @@ class TestLLMConfigurationEndpoint:
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,7 +157,12 @@ class TestLLMConfigurationEndpoint:
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -184,9 +194,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
_create_test_provider(db_session, provider_name, api_key=original_api_key)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -194,13 +202,17 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name, # Existing provider
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -247,7 +259,12 @@ class TestLLMConfigurationEndpoint:
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -280,7 +297,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
provider = upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -288,6 +305,12 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
db_session=db_session,
)
@@ -298,14 +321,18 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -346,7 +373,12 @@ class TestLLMConfigurationEndpoint:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model=model_name,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -410,6 +442,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -419,7 +452,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
update_default_provider(provider_1.id, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -439,6 +472,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -465,11 +499,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -478,9 +512,6 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -493,7 +524,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, provider_2_default_model, db_session)
update_default_provider(provider_2.id, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -565,6 +596,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -573,7 +605,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
update_default_provider(provider.id, db_session)
# Test should fail
with patch(

View File

@@ -20,7 +20,6 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import UserRole
@@ -50,6 +49,7 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -91,14 +91,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -125,16 +125,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -159,16 +159,14 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
_create_test_provider(db_session, provider_name, api_base=original_api_base)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -192,14 +190,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when provider has no custom base URL.
"""
try:
view = _create_test_provider(db_session, provider_name, api_base=None)
_create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -225,16 +223,14 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
_create_test_provider(db_session, provider_name, api_base=original_api_base)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -263,14 +259,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -301,6 +297,7 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -325,7 +322,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -333,11 +330,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -365,15 +362,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -402,7 +399,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -410,13 +407,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -441,17 +438,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -477,7 +474,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -485,10 +482,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -535,7 +532,12 @@ def test_upload_with_custom_config_then_change(
LLMTestRequest(
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -544,10 +546,11 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -566,10 +569,14 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
id=provider.id,
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=False,
),
@@ -579,9 +586,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -609,9 +616,7 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_llm_provider_view(
db_session=db_session, provider_name=name
)
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
assert False, "Provider not found in the database"
@@ -637,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
view = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -659,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -713,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
view = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -735,10 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
id=view.id,
name=name,
provider=provider,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -18,7 +18,6 @@ from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -136,6 +135,7 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -163,8 +163,13 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, expected_default_model, db_session)
update_default_provider(provider.id, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -233,6 +238,7 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -304,13 +310,14 @@ class TestAutoModeSyncFeature:
try:
# Step 1: Upload provider WITHOUT auto mode, with initial models
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -337,12 +344,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -353,8 +360,8 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider = fetch_llm_provider_view(
db_session=db_session, provider_name=provider_name
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is True
@@ -381,6 +388,9 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -422,12 +432,8 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
@@ -529,6 +535,7 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -542,7 +549,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, provider_1_default_model, db_session)
update_default_provider(provider_1.id, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -556,6 +563,7 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -576,7 +584,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, provider_2_default_model, db_session)
update_default_provider(provider_2.id, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()

View File

@@ -64,6 +64,7 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
update_default_provider(public_provider_id, db_session)
try:
# Create chat session

View File

@@ -434,6 +434,7 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -447,7 +448,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, "gpt-4o", db_session)
update_default_provider(provider_view.id, db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""

View File

@@ -990,27 +990,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self._respond_json(
200, {"file_id": f"mock-ci-file-{self.server._file_counter}"}
)
elif self.path == "/v1/execute/stream":
if self.server.streaming_enabled:
self._respond_sse(
[
(
"output",
{"stream": "stdout", "data": "mock output\n"},
),
(
"result",
{
"exit_code": 0,
"timed_out": False,
"duration_ms": 50,
"files": [],
},
),
]
)
else:
self._respond_json(404, {"error": "not found"})
elif self.path == "/v1/execute":
self._respond_json(
200,
@@ -1048,17 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(payload)
def _respond_sse(self, events: list[tuple[str, dict[str, Any]]]) -> None:
frames = []
for event_type, data in events:
frames.append(f"event: {event_type}\ndata: {json.dumps(data)}\n\n")
payload = "".join(frames).encode()
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Content-Length", str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
pass
@@ -1070,7 +1038,6 @@ class MockCodeInterpreterServer(HTTPServer):
super().__init__(("localhost", 0), _MockCIHandler)
self.captured_requests: list[CapturedRequest] = []
self._file_counter = 0
self.streaming_enabled: bool = True
@property
def url(self) -> str:
@@ -1201,19 +1168,17 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed via streaming, staged file cleaned up
# Verify: file uploaded, code executed, staged file cleaned up
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/stream"
)[0].json_body()
execute_body = mock_ci_server.get_requests(method="POST", path="/v1/execute")[
0
].json_body()
assert execute_body["code"] == code
assert len(execute_body["files"]) == 1
assert execute_body["files"][0]["path"] == "data.csv"
@@ -1319,9 +1284,7 @@ def test_code_interpreter_replay_packets_include_code_and_output(
db_session=db_session,
)
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
# The response contains `packets` — a list of packet-lists, one per
# assistant message. We should have exactly one assistant message.
@@ -1350,76 +1313,3 @@ def test_code_interpreter_replay_packets_include_code_and_output(
delta_obj = delta_packets[0].obj
assert isinstance(delta_obj, PythonToolDelta)
assert "mock output" in delta_obj.stdout
def test_code_interpreter_streaming_fallback_to_batch(
db_session: Session,
mock_ci_server: MockCodeInterpreterServer,
_attach_python_tool_to_default_persona: None,
initialize_file_store: None, # noqa: ARG001
) -> None:
"""When the streaming endpoint is not available (older code-interpreter),
execute_streaming should fall back to the batch /v1/execute endpoint."""
mock_ci_server.captured_requests.clear()
mock_ci_server._file_counter = 0
mock_ci_server.streaming_enabled = False
mock_url = mock_ci_server.url
user = create_test_user(db_session, "ci_fallback_test")
chat_session = create_chat_session(db_session=db_session, user=user)
code = 'print("fallback test")'
msg_req = SendMessageRequest(
message="Print fallback test",
chat_session_id=chat_session.id,
stream=True,
)
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
with (
use_mock_llm() as mock_llm,
patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
mock_url,
),
patch(
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
mock_url,
),
):
mock_llm.add_response(
LLMToolCallResponse(
tool_name="python",
tool_call_id="call_fallback",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
)
mock_llm.forward_till_end()
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
try:
packets = list(
handle_stream_message_objects(
new_msg_req=msg_req, user=user, db_session=db_session
)
)
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
mock_ci_server.streaming_enabled = True
# Streaming was attempted first (returned 404), then fell back to batch
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
# Verify output still made it through
delta_packets = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, PythonToolDelta)
]
assert len(delta_packets) >= 1
first_delta = delta_packets[0].obj
assert isinstance(first_delta, PythonToolDelta)
assert "mock output" in first_delta.stdout

View File

@@ -1,53 +0,0 @@
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
environment-variable check passes and the DB flag is the deciding factor.
"""
from unittest.mock import patch
from sqlalchemy.orm import Session
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.tools.tool_implementations.python.python_tool import PythonTool
def test_python_tool_unavailable_when_server_disabled(
db_session: Session,
) -> None:
"""With a valid base URL, the tool should be unavailable when
server_enabled is False in the DB."""
server = fetch_code_interpreter_server(db_session)
initial_enabled = server.server_enabled
try:
update_code_interpreter_server_enabled(db_session, enabled=False)
with patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://fake:8888",
):
assert PythonTool.is_available(db_session) is False
finally:
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
def test_python_tool_available_when_server_enabled(
db_session: Session,
) -> None:
"""With a valid base URL, the tool should be available when
server_enabled is True in the DB."""
server = fetch_code_interpreter_server(db_session)
initial_enabled = server.server_enabled
try:
update_code_interpreter_server_enabled(db_session, enabled=True)
with patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://fake:8888",
):
assert PythonTool.is_available(db_session) is True
finally:
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)

View File

@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s", "-rs"]
ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]

View File

@@ -4,12 +4,10 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -34,6 +32,7 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -66,6 +65,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -75,20 +75,9 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
)
set_default_response.raise_for_status()
@@ -124,12 +113,7 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -142,25 +126,11 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and default_model.model_name in model_names
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/default",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return DefaultModel(**response.json())

View File

@@ -1,4 +1,3 @@
import time
from datetime import datetime
from urllib.parse import urlencode
from uuid import UUID
@@ -9,10 +8,8 @@ from requests.models import CaseInsensitiveDict
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.db.enums import TaskStatus
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestUser
@@ -72,42 +69,9 @@ class QueryHistoryManager:
if end_time:
query_params["end"] = end_time.isoformat()
start_response = requests.post(
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
)
start_response.raise_for_status()
request_id = start_response.json()["request_id"]
deadline = time.time() + MAX_DELAY
while time.time() < deadline:
status_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/export-status",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
status_response.raise_for_status()
status = status_response.json()["status"]
if status == TaskStatus.SUCCESS:
break
if status == TaskStatus.FAILURE:
raise RuntimeError("Query history export task failed")
time.sleep(2)
else:
raise TimeoutError(
f"Query history export not completed within {MAX_DELAY} seconds"
)
download_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/download",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
download_response.raise_for_status()
if not download_response.content:
raise RuntimeError(
"Query history CSV download returned zero-length content"
)
return download_response.headers, download_response.content.decode()
response.raise_for_status()
return response.headers, response.content.decode()

View File

@@ -116,6 +116,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str
is_public: bool
is_auto_mode: bool = False
groups: list[int]

View File

@@ -6,26 +6,16 @@ import pytest
from onyx.connectors.slack.models import ChannelType
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
# from tests.load_env_vars import load_env_vars
# load_env_vars()
def _provision_slack_channels(
bot_token: str,
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(bot_token)
auth_info = slack_client.auth_test()
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
if SLACK_ADMIN_EMAIL not in user_map:
raise KeyError(
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
f"Available emails: {sorted(user_map.keys())}"
)
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
admin_user_id = user_map["admin@example.com"]
(
public_channel,
@@ -37,16 +27,5 @@ def _provision_slack_channels(
yield public_channel, private_channel
# This part will always run after the test, even if it fails
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
@pytest.fixture()
def slack_perm_sync_test_setup() -> (
Generator[tuple[ChannelType, ChannelType], None, None]
):
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])

Some files were not shown because too many files have changed in this diff Show More