mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 12:15:48 +00:00
Compare commits
15 Commits
refactor/l
...
csv_render
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80cf389774 | ||
|
|
e775aaacb7 | ||
|
|
e5b08b3d92 | ||
|
|
7c91304ba2 | ||
|
|
68a292b500 | ||
|
|
e553b80030 | ||
|
|
f3949f8e09 | ||
|
|
c7c064e296 | ||
|
|
68b91a8862 | ||
|
|
c23e5a196d | ||
|
|
093223c6c4 | ||
|
|
89517111d4 | ||
|
|
883d4b4ceb | ||
|
|
f3672b6819 | ||
|
|
921f5d9e96 |
73
.github/actions/build-backend-image/action.yml
vendored
73
.github/actions/build-backend-image/action.yml
vendored
@@ -1,73 +0,0 @@
|
||||
name: "Build Backend Image"
|
||||
description: "Builds and pushes the backend Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
docker-no-cache:
|
||||
description: "Set to 'true' to disable docker build cache"
|
||||
required: false
|
||||
default: "false"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
|
||||
no-cache: ${{ inputs.docker-no-cache == 'true' }}
|
||||
@@ -1,75 +0,0 @@
|
||||
name: "Build Integration Image"
|
||||
description: "Builds and pushes the integration test image with docker bake"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
@@ -1,68 +0,0 @@
|
||||
name: "Build Model Server Image"
|
||||
description: "Builds and pushes the model server Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
|
||||
@@ -1,120 +0,0 @@
|
||||
name: "Run Nightly Provider Chat Test"
|
||||
description: "Starts required compose services and runs nightly provider integration test"
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
|
||||
required: true
|
||||
models:
|
||||
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
api-base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
default: ""
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in image tags"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
shell: bash
|
||||
env:
|
||||
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
run: |
|
||||
cat <<EOF2 > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
|
||||
- name: Start Docker containers
|
||||
shell: bash
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server
|
||||
|
||||
- name: Run nightly provider integration test
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
env:
|
||||
MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
|
||||
@@ -1,44 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -11,11 +11,6 @@ permissions:
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -41,13 +36,9 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
@@ -80,84 +71,9 @@ jobs:
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
|
||||
reason="command-failed"
|
||||
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
|
||||
reason="merge-conflict"
|
||||
fi
|
||||
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "details<<EOF"
|
||||
tail -n 40 "$output_file"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
|
||||
@@ -116,6 +116,7 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -20,7 +20,6 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
@@ -424,7 +423,6 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
@@ -445,7 +443,6 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
@@ -704,7 +701,6 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
name: Reusable Nightly LLM Provider Chat Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose down -v
|
||||
@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
@@ -616,9 +616,3 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found at `contributing_guides/best_practices.md`.
|
||||
Understand its contents and follow them.
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""LLMProvider deprecated fields are nullable
|
||||
|
||||
Revision ID: 001984c88745
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-01 22:24:34.171100
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "001984c88745"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
# is_default_provider and default_vision_model are already nullable with no server_default
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""code interpreter seed
|
||||
|
||||
Revision ID: 07b98176f1de
|
||||
Revises: 7cb492013621
|
||||
Create Date: 2026-02-23 15:55:07.606784
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "07b98176f1de"
|
||||
down_revision = "7cb492013621"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Seed the single instance of code_interpreter_server
|
||||
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
|
||||
op.execute(
|
||||
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(sa.text("DELETE FROM code_interpreter_server"))
|
||||
@@ -1,48 +0,0 @@
|
||||
"""add enterprise and name fields to scim_user_mapping
|
||||
|
||||
Revision ID: 7616121f6e97
|
||||
Revises: 07b98176f1de
|
||||
Create Date: 2026-02-23 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7616121f6e97"
|
||||
down_revision = "07b98176f1de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("department", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("manager", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("given_name", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("family_name", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_emails_json", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_emails_json")
|
||||
op.drop_column("scim_user_mapping", "family_name")
|
||||
op.drop_column("scim_user_mapping", "given_name")
|
||||
op.drop_column("scim_user_mapping", "manager")
|
||||
op.drop_column("scim_user_mapping", "department")
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_persona_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_file", "needs_persona_sync")
|
||||
@@ -26,13 +26,14 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
@@ -50,24 +51,12 @@ from onyx.db.models import ScimToken
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
|
||||
media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
|
||||
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
@@ -97,39 +86,15 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -137,45 +102,15 @@ def get_schemas() -> ScimJSONResponse:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return ScimJSONResponse(
|
||||
return JSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
|
||||
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
|
||||
|
||||
Returns a set of lowercased attribute names to omit from responses.
|
||||
"""
|
||||
if not raw:
|
||||
return set()
|
||||
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
|
||||
|
||||
|
||||
def _apply_exclusions(
|
||||
resource: ScimUserResource | ScimGroupResource,
|
||||
excluded: set[str],
|
||||
) -> dict:
|
||||
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
|
||||
|
||||
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
|
||||
to reduce response payload size. We strip those fields after serialization
|
||||
so the rest of the pipeline doesn't need to know about them.
|
||||
"""
|
||||
data = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
for attr in excluded:
|
||||
# Match case-insensitively against the camelCase field names
|
||||
keys_to_remove = [k for k in data if k.lower() == attr]
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
return data
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -189,7 +124,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
@@ -209,56 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# If the client explicitly provides ``formatted``, prefer it — the client
|
||||
# knows what display string it wants. Otherwise build from components.
|
||||
if name.formatted:
|
||||
return name.formatted
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or None
|
||||
|
||||
|
||||
def _scim_resource_response(
|
||||
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
|
||||
status_code: int = 200,
|
||||
) -> ScimJSONResponse:
|
||||
"""Serialize a SCIM resource as ``application/scim+json``."""
|
||||
content = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
return ScimJSONResponse(
|
||||
status_code=status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_list_response(
|
||||
resources: list[ScimUserResource | ScimGroupResource],
|
||||
total: int,
|
||||
start_index: int,
|
||||
count: int,
|
||||
excluded: set[str] | None = None,
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""Build a SCIM list response, optionally applying attribute exclusions.
|
||||
|
||||
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
|
||||
the ``excludedAttributes`` query parameter.
|
||||
"""
|
||||
if excluded:
|
||||
envelope = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = envelope.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
return _scim_resource_response(
|
||||
ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
)
|
||||
return parts or name.formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -269,13 +158,12 @@ def _build_list_response(
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -301,48 +189,38 @@ def list_users(
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
@@ -350,7 +228,7 @@ def create_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -398,10 +276,7 @@ def create_user(
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(user, external_id, scim_username=scim_username),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -411,13 +286,13 @@ def replace_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -442,13 +317,11 @@ def replace_user(
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -459,7 +332,7 @@ def patch_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -469,7 +342,7 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -526,13 +399,11 @@ def patch_user(
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -541,29 +412,25 @@ def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
A second DELETE returns 404 per RFC 7644 §3.6.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# If no SCIM mapping exists, the user was already deleted from
|
||||
# SCIM's perspective — return 404 per RFC 7644 §3.6.
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if not mapping:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
dal.deactivate_user(user)
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -575,7 +442,7 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
@@ -630,13 +497,12 @@ def _validate_and_parse_members(
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -656,46 +522,37 @@ def list_groups(
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
resource = provider.build_group_resource(
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
@@ -703,7 +560,7 @@ def create_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -739,10 +596,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(db_group, members, external_id),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -752,13 +606,13 @@ def replace_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -773,9 +627,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, group_resource.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -785,7 +637,7 @@ def patch_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
@@ -794,7 +646,7 @@ def patch_group(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -833,9 +685,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, patched.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
@@ -843,13 +693,13 @@ def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -166,19 +165,6 @@ class ScimPatchOperation(BaseModel):
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
@field_validator("op", mode="before")
|
||||
@classmethod
|
||||
def normalize_operation(cls, v: object) -> object:
|
||||
"""Normalize op to lowercase for case-insensitive matching.
|
||||
|
||||
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
|
||||
instead of ``"replace"``. This is safe for all providers since the
|
||||
enum values are lowercase. If a future provider requires other
|
||||
pre-processing quirks, move patch deserialization into the provider
|
||||
subclass instead of adding more special cases here.
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
@@ -15,8 +15,6 @@ responsible for persisting changes.
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
@@ -26,50 +24,6 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[.+\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -80,17 +34,11 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _UserPatchCtx:
|
||||
"""Bundles the mutable state for user PATCH operations."""
|
||||
|
||||
data: dict[str, Any]
|
||||
name_data: dict[str, Any]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
@@ -112,126 +60,72 @@ def apply_user_patch(
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
|
||||
name_data = data.get("name") or {}
|
||||
|
||||
for op in operations:
|
||||
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
|
||||
_apply_user_replace(op, ctx, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_user_remove(op, ctx, ignored_paths)
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data)
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set.
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a remove operation to user data — clears the target field."""
|
||||
path = (op.path or "").lower()
|
||||
if not path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _USER_REMOVE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = None
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path.
|
||||
|
||||
Args:
|
||||
strict: When ``False`` (path-less replace), unknown attributes are
|
||||
silently skipped. When ``True`` (explicit path), they raise.
|
||||
"""
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
# Simple field writes handled by the dispatch table
|
||||
entry = _USER_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = value
|
||||
return
|
||||
|
||||
# displayName sets both the top-level field and the name.formatted sub-field
|
||||
if path == "displayname":
|
||||
ctx.data["displayName"] = value
|
||||
ctx.name_data["formatted"] = value
|
||||
elif path == "name":
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
ctx.name_data[k] = v
|
||||
elif path == "emails":
|
||||
if isinstance(value, list):
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif not strict:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via an email filter path."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
@@ -341,14 +235,12 @@ def _set_group_field(
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _GROUP_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, _ = entry
|
||||
data[key] = value
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
|
||||
@@ -16,14 +16,6 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
@@ -23,4 +22,4 @@ class OktaProvider(ScimProvider):
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return COMMON_IGNORED_PATCH_PATHS
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
|
||||
@@ -123,21 +123,9 @@ def _seed_llms(
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
if len(seeded_providers[0].model_configurations) > 0:
|
||||
default_model = next(
|
||||
(
|
||||
mc
|
||||
for mc in seeded_providers[0].model_configurations
|
||||
if mc.is_visible
|
||||
),
|
||||
seeded_providers[0].model_configurations[0],
|
||||
).name
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id,
|
||||
model_name=default_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -302,12 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -360,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -391,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -429,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
|
||||
@@ -59,7 +58,6 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
|
||||
METADATA_SUFFIX,
|
||||
DOCUMENT_SETS,
|
||||
USER_PROJECT,
|
||||
PERSONAS,
|
||||
PRIMARY_OWNERS,
|
||||
SECONDARY_OWNERS,
|
||||
ACCESS_CONTROL_LIST,
|
||||
@@ -278,7 +276,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
personas: list[int] | None = vespa_chunk.get(PERSONAS)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
@@ -328,7 +325,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
personas=personas,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
|
||||
@@ -5,14 +5,11 @@ from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
@@ -27,14 +24,12 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -80,58 +75,10 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
|
||||
|
||||
def enqueue_user_file_project_sync_task(
|
||||
*,
|
||||
celery_app: Celery,
|
||||
redis_client: Redis,
|
||||
user_file_id: str | UUID,
|
||||
tenant_id: str,
|
||||
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
|
||||
) -> bool:
|
||||
"""Enqueue a project-sync task if no matching queued task already exists."""
|
||||
queued_key = _user_file_project_sync_queued_key(user_file_id)
|
||||
|
||||
# NX+EX gives us atomic dedupe and a self-healing TTL.
|
||||
queued_guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
nx=True,
|
||||
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
if not queued_guard_set:
|
||||
return False
|
||||
|
||||
try:
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=priority,
|
||||
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
# Roll back the queued guard if task publish fails.
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
|
||||
def _visit_chunks(
|
||||
*,
|
||||
@@ -685,8 +632,8 @@ def process_single_user_file_delete(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files needing project sync and enqueue per-file tasks."""
|
||||
task_logger.info("Starting")
|
||||
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
|
||||
task_logger.info("check_for_user_file_project_sync - Starting")
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -698,25 +645,13 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(self.app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"Queue depth {queue_depth} exceeds "
|
||||
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -726,23 +661,19 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
if not enqueue_user_file_project_sync_task(
|
||||
celery_app=self.app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
):
|
||||
skipped_guard += 1
|
||||
continue
|
||||
)
|
||||
enqueued += 1
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"Enqueued {enqueued} "
|
||||
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
|
||||
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -761,8 +692,6 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
@@ -776,11 +705,7 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
@@ -808,17 +733,13 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -826,7 +747,6 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
|
||||
@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -158,7 +156,36 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
logger.warning(
|
||||
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
|
||||
)
|
||||
cleaned_batch.append(sanitize_document_for_postgres(doc))
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
@@ -575,13 +602,10 @@ def connector_document_extraction(
|
||||
|
||||
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
|
||||
if hierarchy_node_batch:
|
||||
hierarchy_node_batch_cleaned = (
|
||||
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=hierarchy_node_batch_cleaned,
|
||||
nodes=hierarchy_node_batch,
|
||||
source=db_connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
@@ -600,7 +624,7 @@ def connector_document_extraction(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
|
||||
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
@@ -657,12 +656,7 @@ def run_llm_loop(
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
# Fetch this in a short-lived session so the long-running stream loop does
|
||||
# not pin a connection just to keep read state alive.
|
||||
with get_session_with_current_tenant() as prompt_db_session:
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(
|
||||
prompt_db_session
|
||||
)
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
system_prompt = None
|
||||
custom_agent_prompt_msg = None
|
||||
|
||||
|
||||
@@ -856,11 +856,6 @@ def handle_stream_message_objects(
|
||||
reserved_tokens=reserved_token_count,
|
||||
)
|
||||
|
||||
# Release any read transaction before entering the long-running LLM stream.
|
||||
# Without this, the request-scoped session can keep a connection checked out
|
||||
# for the full stream duration.
|
||||
db_session.commit()
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
# Set this right before launching the LLM loop so run_in_background copies the right context.
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
|
||||
@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER") or ""
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
|
||||
|
||||
@@ -167,14 +167,6 @@ CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
# How long a queued user-file-project-sync task remains valid.
|
||||
# Should be short enough to discard stale queue entries under load while still
|
||||
# allowing workers enough time to pick up new tasks.
|
||||
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before user-file-project-sync producers stop enqueuing.
|
||||
# This applies backpressure when workers are falling behind.
|
||||
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
@@ -467,7 +459,6 @@ class OnyxRedisLocks:
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
|
||||
|
||||
@@ -16,22 +16,6 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
|
||||
|
||||
|
||||
def _is_rate_limit_error(error: HttpError) -> bool:
|
||||
"""Google sometimes returns rate-limit errors as 403 with reason
|
||||
'userRateLimitExceeded' instead of 429. This helper detects both."""
|
||||
if error.resp.status == 429:
|
||||
return True
|
||||
if error.resp.status != 403:
|
||||
return False
|
||||
error_details = getattr(error, "error_details", None) or []
|
||||
for detail in error_details:
|
||||
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
|
||||
return True
|
||||
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
@@ -73,7 +57,7 @@ def _execute_with_retry(request: Any) -> Any:
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if _is_rate_limit_error(error):
|
||||
if error.resp.status == 429:
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
@@ -156,16 +140,16 @@ def _execute_single_retrieval(
|
||||
)
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
else:
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
|
||||
|
||||
The office365 library's GraphClient requires an ``AzureEnvironment`` string
|
||||
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
|
||||
cloud. Our connectors instead expose free-text ``authority_host`` and
|
||||
``graph_api_host`` fields so the frontend doesn't need to know about SDK
|
||||
internals.
|
||||
|
||||
This module bridges the gap: given the two host URLs the user configured, it
|
||||
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
|
||||
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
|
||||
"""
|
||||
|
||||
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
|
||||
|
||||
class MicrosoftGraphEnvironment(BaseModel):
|
||||
"""One row of the inverse mapping."""
|
||||
|
||||
environment: str
|
||||
graph_host: str
|
||||
authority_host: str
|
||||
sharepoint_domain_suffix: str
|
||||
|
||||
|
||||
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.Global,
|
||||
graph_host="https://graph.microsoft.com",
|
||||
authority_host="https://login.microsoftonline.com",
|
||||
sharepoint_domain_suffix="sharepoint.com",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.USGovernmentHigh,
|
||||
graph_host="https://graph.microsoft.us",
|
||||
authority_host="https://login.microsoftonline.us",
|
||||
sharepoint_domain_suffix="sharepoint.us",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.USGovernmentDoD,
|
||||
graph_host="https://dod-graph.microsoft.us",
|
||||
authority_host="https://login.microsoftonline.us",
|
||||
sharepoint_domain_suffix="sharepoint.us",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.China,
|
||||
graph_host="https://microsoftgraph.chinacloudapi.cn",
|
||||
authority_host="https://login.chinacloudapi.cn",
|
||||
sharepoint_domain_suffix="sharepoint.cn",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.Germany,
|
||||
graph_host="https://graph.microsoft.de",
|
||||
authority_host="https://login.microsoftonline.de",
|
||||
sharepoint_domain_suffix="sharepoint.de",
|
||||
),
|
||||
]
|
||||
|
||||
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
|
||||
env.graph_host: env for env in _ENVIRONMENTS
|
||||
}
|
||||
|
||||
|
||||
def resolve_microsoft_environment(
|
||||
graph_api_host: str,
|
||||
authority_host: str,
|
||||
) -> MicrosoftGraphEnvironment:
|
||||
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
|
||||
|
||||
Raises ``ConnectorValidationError`` when the combination is unknown or
|
||||
internally inconsistent (e.g. a GCC-High graph host paired with a
|
||||
commercial authority host).
|
||||
"""
|
||||
graph_api_host = graph_api_host.rstrip("/")
|
||||
authority_host = authority_host.rstrip("/")
|
||||
|
||||
env = _GRAPH_HOST_INDEX.get(graph_api_host)
|
||||
if env is None:
|
||||
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
|
||||
raise ConnectorValidationError(
|
||||
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
|
||||
f"Recognised hosts: {known}"
|
||||
)
|
||||
|
||||
if env.authority_host != authority_host:
|
||||
raise ConnectorValidationError(
|
||||
f"Authority host '{authority_host}' is inconsistent with "
|
||||
f"graph API host '{graph_api_host}'. "
|
||||
f"Expected authority host '{env.authority_host}' "
|
||||
f"for the {env.environment} environment."
|
||||
)
|
||||
|
||||
return env
|
||||
@@ -6,7 +6,6 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
key: [str(item) for item in val] if isinstance(val, list) else str(val)
|
||||
for key, val in v.items()
|
||||
}
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -147,9 +146,7 @@ class DriveItemData(BaseModel):
|
||||
self.id,
|
||||
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
|
||||
)
|
||||
item = DriveItem(graph_client, path)
|
||||
item.set_property("id", self.id)
|
||||
return item
|
||||
return DriveItem(graph_client, path)
|
||||
|
||||
|
||||
# The office365 library's ClientContext caches the access token from its
|
||||
@@ -840,20 +837,10 @@ class SharepointConnector(
|
||||
self._cached_rest_ctx: ClientContext | None = None
|
||||
self._cached_rest_ctx_url: str | None = None
|
||||
self._cached_rest_ctx_created_at: float = 0.0
|
||||
|
||||
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
|
||||
self._azure_environment = resolved_env.environment
|
||||
self.authority_host = resolved_env.authority_host
|
||||
self.graph_api_host = resolved_env.graph_host
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
self.graph_api_base = f"{self.graph_api_host}/v1.0"
|
||||
self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
|
||||
if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
|
||||
logger.warning(
|
||||
f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
|
||||
f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
|
||||
f"for the {resolved_env.environment} environment. "
|
||||
f"Using '{resolved_env.sharepoint_domain_suffix}'."
|
||||
)
|
||||
self.sharepoint_domain_suffix = sharepoint_domain_suffix
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Validate that at least one content type is enabled
|
||||
@@ -1605,7 +1592,6 @@ class SharepointConnector(
|
||||
if certificate_data is None:
|
||||
raise RuntimeError("Failed to load certificate")
|
||||
|
||||
logger.info(f"Creating MSAL app with authority url {authority_url}")
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
@@ -1637,9 +1623,7 @@ class SharepointConnector(
|
||||
raise ConnectorValidationError("Failed to acquire token for graph")
|
||||
return token
|
||||
|
||||
self._graph_client = GraphClient(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
self._graph_client = GraphClient(_acquire_token_for_graph)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
|
||||
@@ -11,7 +11,6 @@ from dateutil import parser
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -259,21 +258,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Very basic validation, we could do more here
|
||||
"""
|
||||
if not self.base_url.startswith("https://") and not self.base_url.startswith(
|
||||
"http://"
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
"Base URL must start with https:// or http://"
|
||||
)
|
||||
|
||||
try:
|
||||
get_all_post_ids(self.slab_bot_token)
|
||||
except ConnectorMissingCredentialError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -74,11 +73,8 @@ class TeamsConnector(
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
self.max_workers = max_workers
|
||||
self.requested_team_list: list[str] = teams
|
||||
|
||||
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
|
||||
self._azure_environment = resolved_env.environment
|
||||
self.authority_host = resolved_env.authority_host
|
||||
self.graph_api_host = resolved_env.graph_host
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
|
||||
# impls for BaseConnector
|
||||
|
||||
@@ -110,9 +106,7 @@ class TeamsConnector(
|
||||
|
||||
return token
|
||||
|
||||
self.graph_client = GraphClient(
|
||||
_acquire_token_func, environment=self._azure_environment
|
||||
)
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import CodeInterpreterServer
|
||||
|
||||
|
||||
def fetch_code_interpreter_server(
|
||||
db_session: Session,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
return server
|
||||
|
||||
|
||||
def update_code_interpreter_server_enabled(
|
||||
db_session: Session,
|
||||
enabled: bool,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
server.server_enabled = enabled
|
||||
db_session.commit()
|
||||
return server
|
||||
@@ -213,12 +213,8 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = (
|
||||
fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if llm_provider_upsert_request.id
|
||||
else None
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
@@ -242,6 +238,11 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -250,10 +251,6 @@ def upsert_llm_provider(
|
||||
# If its not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
models_to_exist = {
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of existing model configurations by name (single iteration)
|
||||
existing_by_name = {
|
||||
mc.name: mc for mc in existing_llm_provider.model_configurations
|
||||
@@ -309,6 +306,15 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -482,22 +488,6 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -614,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
model_name,
|
||||
provider.default_model_name,
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -806,6 +805,12 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
@@ -861,6 +866,7 @@ def insert_new_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
.returning(ModelConfiguration.id)
|
||||
@@ -895,6 +901,7 @@ def update_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.where(ModelConfiguration.id == model_configuration_id)
|
||||
.returning(ModelConfiguration)
|
||||
|
||||
@@ -2822,9 +2822,14 @@ class LLMProvider(Base):
|
||||
custom_config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
# Auto mode: models, visibility, and defaults are managed by GitHub config
|
||||
@@ -2874,6 +2879,8 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
|
||||
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
|
||||
@@ -4263,9 +4270,6 @@ class UserFile(Base):
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
needs_persona_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4936,11 +4940,6 @@ class ScimUserMapping(Base):
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
department: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
manager: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
|
||||
@@ -765,9 +765,6 @@ def mark_persona_as_deleted(
|
||||
) -> None:
|
||||
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
|
||||
persona.deleted = True
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -779,13 +776,11 @@ def mark_persona_as_not_deleted(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
|
||||
)
|
||||
if not persona.deleted:
|
||||
if persona.deleted:
|
||||
persona.deleted = False
|
||||
db_session.commit()
|
||||
else:
|
||||
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
|
||||
persona.deleted = False
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_delete_persona_by_name(
|
||||
@@ -851,20 +846,6 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
) -> None:
|
||||
"""Flag the given UserFile rows so the background sync task picks them up
|
||||
and updates their persona metadata in the vector DB."""
|
||||
if not user_file_ids:
|
||||
return
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
|
||||
{UserFile.needs_persona_sync: True},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -1053,13 +1034,8 @@ def upsert_persona(
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
old_file_ids = {uf.id for uf in existing_persona.user_files}
|
||||
new_file_ids = {uf.id for uf in (user_files or [])}
|
||||
affected_file_ids = old_file_ids | new_file_ids
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
|
||||
|
||||
if hierarchy_node_ids is not None:
|
||||
existing_persona.hierarchy_nodes.clear()
|
||||
@@ -1113,8 +1089,6 @@ def upsert_persona(
|
||||
attached_documents=attached_documents or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
if user_files:
|
||||
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
|
||||
persona = new_persona
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
@@ -2,7 +2,6 @@ import random
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import create_chat_session
|
||||
@@ -14,26 +13,18 @@ from onyx.db.models import ChatSession
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def seed_chat_history(
|
||||
num_sessions: int,
|
||||
num_messages: int,
|
||||
days: int,
|
||||
user_id: UUID | None = None,
|
||||
persona_id: int | None = None,
|
||||
) -> None:
|
||||
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
"""Utility function to seed chat history for testing.
|
||||
|
||||
num_sessions: the number of sessions to seed
|
||||
num_messages: the number of messages to seed per sessions
|
||||
days: the number of days looking backwards from the current time over which to randomize
|
||||
the times.
|
||||
user_id: optional user to associate with sessions
|
||||
persona_id: optional persona/assistant to associate with sessions
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.info(f"Seeding {num_sessions} sessions.")
|
||||
for y in range(0, num_sessions):
|
||||
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
|
||||
create_chat_session(db_session, f"pytest_session_{y}", None, None)
|
||||
|
||||
# randomize all session times
|
||||
logger.info(f"Seeding {num_messages} messages per session.")
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import UserFile
|
||||
@@ -65,23 +64,6 @@ def fetch_user_project_ids_for_user_files(
|
||||
}
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch persona (assistant) ids for specified user files."""
|
||||
stmt = (
|
||||
select(UserFile)
|
||||
.where(UserFile.id.in_(user_file_ids))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
)
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [persona.id for persona in user_file.assistants]
|
||||
for user_file in results
|
||||
}
|
||||
|
||||
|
||||
def update_last_accessed_at_for_user_files(
|
||||
user_file_ids: list[UUID],
|
||||
db_session: Session,
|
||||
|
||||
@@ -121,7 +121,6 @@ class VespaDocumentUserFields:
|
||||
"""
|
||||
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -148,7 +148,6 @@ class MetadataUpdateRequest(BaseModel):
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
persona_ids: set[int] | None = None
|
||||
|
||||
|
||||
class IndexRetrievalFilters(BaseModel):
|
||||
|
||||
@@ -50,7 +50,6 @@ from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
@@ -216,7 +215,6 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
# OpenSearch and it will not store any data at all for this field, which
|
||||
# is different from supplying an empty list.
|
||||
user_projects=chunk.user_project or None,
|
||||
personas=chunk.personas or None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
@@ -364,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
if user_fields and user_fields.personas
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -716,10 +709,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
|
||||
update_request.project_ids
|
||||
)
|
||||
if update_request.persona_ids is not None:
|
||||
properties_to_update[PERSONAS_FIELD_NAME] = list(
|
||||
update_request.persona_ids
|
||||
)
|
||||
|
||||
if not properties_to_update:
|
||||
if len(update_request.document_ids) > 1:
|
||||
|
||||
@@ -41,7 +41,6 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
USER_PROJECTS_FIELD_NAME = "user_projects"
|
||||
PERSONAS_FIELD_NAME = "personas"
|
||||
DOCUMENT_ID_FIELD_NAME = "document_id"
|
||||
CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
@@ -157,7 +156,6 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
@@ -487,7 +485,6 @@ class DocumentSchema:
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
|
||||
PERSONAS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
|
||||
@@ -181,11 +181,6 @@ schema {{ schema_name }} {
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field personas type array<int> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
|
||||
@@ -689,9 +689,6 @@ class VespaIndex(DocumentIndex):
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
document_ids=[doc_id],
|
||||
doc_id_to_chunk_cnt={
|
||||
@@ -702,7 +699,6 @@ class VespaIndex(DocumentIndex):
|
||||
boost=fields.boost if fields is not None else None,
|
||||
hidden=fields.hidden if fields is not None else None,
|
||||
project_ids=project_ids,
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
@@ -46,7 +46,6 @@ from onyx.document_index.vespa_constants import METADATA
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
|
||||
@@ -219,7 +218,6 @@ def _index_vespa_chunk(
|
||||
# still called `image_file_name` in Vespa for backwards compatibility
|
||||
IMAGE_FILE_NAME: chunk.image_file_id,
|
||||
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
|
||||
PERSONAS: chunk.personas if chunk.personas is not None else [],
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
@@ -183,10 +183,6 @@ def _update_single_chunk(
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _Personas(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
@@ -197,7 +193,6 @@ def _update_single_chunk(
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
personas: _Personas | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
@@ -232,11 +227,6 @@ def _update_single_chunk(
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
personas_update: _Personas | None = (
|
||||
_Personas(assign=list(update_request.persona_ids))
|
||||
if update_request.persona_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
@@ -244,7 +234,6 @@ def _update_single_chunk(
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
personas=personas_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
|
||||
@@ -58,7 +58,6 @@ DOCUMENT_SETS = "document_sets"
|
||||
USER_FILE = "user_file"
|
||||
USER_FOLDER = "user_folder"
|
||||
USER_PROJECT = "user_project"
|
||||
PERSONAS = "personas"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
|
||||
@@ -12,9 +12,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
|
||||
_DALL_E_2_MODEL_NAME = "dall-e-2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
@@ -56,25 +53,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
deployment_name=credentials.deployment_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# Azure GPT image models support up to 16 input images for edits.
|
||||
return 16
|
||||
|
||||
def _normalize_model_name(self, model: str) -> str:
|
||||
return model.rsplit("/", 1)[-1]
|
||||
|
||||
def _model_supports_image_edits(self, model: str) -> bool:
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
return (
|
||||
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
|
||||
or normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
)
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -82,44 +60,14 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
deployment = self._deployment_name or model
|
||||
model_name = f"azure/{deployment}"
|
||||
|
||||
if reference_images:
|
||||
if not self._model_supports_image_edits(model):
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support image edits with reference images."
|
||||
)
|
||||
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
if (
|
||||
normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
and len(reference_images) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Model 'dall-e-2' only supports a single reference image for edits."
|
||||
)
|
||||
|
||||
from litellm import image_edit
|
||||
|
||||
return image_edit(
|
||||
image=[image.data for image in reference_images],
|
||||
prompt=prompt,
|
||||
model=model_name,
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
api_version=self._api_version,
|
||||
size=size,
|
||||
n=n,
|
||||
quality=quality,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
prompt=prompt,
|
||||
model=model_name,
|
||||
|
||||
@@ -12,9 +12,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
|
||||
_DALL_E_2_MODEL_NAME = "dall-e-2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
@@ -42,25 +39,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
api_base=credentials.api_base,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# GPT image models support up to 16 input images for edits.
|
||||
return 16
|
||||
|
||||
def _normalize_model_name(self, model: str) -> str:
|
||||
return model.rsplit("/", 1)[-1]
|
||||
|
||||
def _model_supports_image_edits(self, model: str) -> bool:
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
return (
|
||||
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
|
||||
or normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
)
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -68,38 +46,9 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
if reference_images:
|
||||
if not self._model_supports_image_edits(model):
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support image edits with reference images."
|
||||
)
|
||||
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
if (
|
||||
normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
and len(reference_images) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Model 'dall-e-2' only supports a single reference image for edits."
|
||||
)
|
||||
|
||||
from litellm import image_edit
|
||||
|
||||
return image_edit(
|
||||
image=[image.data for image in reference_images],
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
size=size,
|
||||
n=n,
|
||||
quality=quality,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
|
||||
@@ -146,7 +146,6 @@ class DocumentIndexingBatchAdapter:
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
@@ -120,10 +119,6 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -187,7 +182,7 @@ class UserFileIndexingAdapter:
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
# we are going to index userfiles only once, so we just set the boost to the default
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
|
||||
@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -229,8 +228,6 @@ def index_doc_batch_prepare(
|
||||
) -> DocumentBatchPrepareContext | None:
|
||||
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
|
||||
This preceeds indexing it into the actual document index."""
|
||||
documents = sanitize_documents_for_postgres(documents)
|
||||
|
||||
# Create a trimmed list of docs that don't have a newer updated at
|
||||
# Shortcuts the time-consuming flow on connector index retries
|
||||
document_ids: list[str] = [document.id for document in documents]
|
||||
|
||||
@@ -112,7 +112,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
user_project: list[int]
|
||||
personas: list[int]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
# Full ancestor path from root hierarchy node to document's parent.
|
||||
@@ -127,7 +126,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
user_project: list[int],
|
||||
personas: list[int],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
@@ -139,7 +137,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
|
||||
|
||||
def _sanitize_json_like(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _sanitize_string(value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_json_like(item) for item in value]
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_sanitize_json_like(item) for item in value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[Any, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
|
||||
return sanitized
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
return expert.model_copy(
|
||||
update={
|
||||
"display_name": (
|
||||
_sanitize_string(expert.display_name)
|
||||
if expert.display_name is not None
|
||||
else None
|
||||
),
|
||||
"first_name": (
|
||||
_sanitize_string(expert.first_name)
|
||||
if expert.first_name is not None
|
||||
else None
|
||||
),
|
||||
"middle_initial": (
|
||||
_sanitize_string(expert.middle_initial)
|
||||
if expert.middle_initial is not None
|
||||
else None
|
||||
),
|
||||
"last_name": (
|
||||
_sanitize_string(expert.last_name)
|
||||
if expert.last_name is not None
|
||||
else None
|
||||
),
|
||||
"email": (
|
||||
_sanitize_string(expert.email) if expert.email is not None else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
|
||||
return ExternalAccess(
|
||||
external_user_emails={
|
||||
_sanitize_string(email) for email in external_access.external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
_sanitize_string(group_id)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
},
|
||||
is_public=external_access.is_public,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
cleaned_doc = document.model_copy(deep=True)
|
||||
|
||||
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
|
||||
if cleaned_doc.title is not None:
|
||||
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
|
||||
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id
|
||||
)
|
||||
|
||||
cleaned_doc.metadata = {
|
||||
_sanitize_string(key): (
|
||||
[_sanitize_string(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else _sanitize_string(value)
|
||||
)
|
||||
for key, value in cleaned_doc.metadata.items()
|
||||
}
|
||||
|
||||
if cleaned_doc.doc_metadata is not None:
|
||||
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
|
||||
if cleaned_doc.primary_owners is not None:
|
||||
cleaned_doc.primary_owners = [
|
||||
_sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
|
||||
]
|
||||
if cleaned_doc.secondary_owners is not None:
|
||||
cleaned_doc.secondary_owners = [
|
||||
_sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
|
||||
]
|
||||
|
||||
if cleaned_doc.external_access is not None:
|
||||
cleaned_doc.external_access = _sanitize_external_access(
|
||||
cleaned_doc.external_access
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = _sanitize_string(section.link)
|
||||
if section.text is not None:
|
||||
section.text = _sanitize_string(section.text)
|
||||
if section.image_file_id is not None:
|
||||
section.image_file_id = _sanitize_string(section.image_file_id)
|
||||
|
||||
return cleaned_doc
|
||||
|
||||
|
||||
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
|
||||
return [sanitize_document_for_postgres(document) for document in documents]
|
||||
|
||||
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
|
||||
cleaned_node = node.model_copy(deep=True)
|
||||
|
||||
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
|
||||
if cleaned_node.raw_parent_id is not None:
|
||||
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
|
||||
if cleaned_node.link is not None:
|
||||
cleaned_node.link = _sanitize_string(cleaned_node.link)
|
||||
|
||||
if cleaned_node.external_access is not None:
|
||||
cleaned_node.external_access = _sanitize_external_access(
|
||||
cleaned_node.external_access
|
||||
)
|
||||
|
||||
return cleaned_node
|
||||
|
||||
|
||||
def sanitize_hierarchy_nodes_for_postgres(
|
||||
nodes: list[HierarchyNode],
|
||||
) -> list[HierarchyNode]:
|
||||
return [sanitize_hierarchy_node_for_postgres(node) for node in nodes]
|
||||
@@ -97,9 +97,6 @@ from onyx.server.features.web_search.api import router as web_search_router
|
||||
from onyx.server.federated.api import router as federated_router
|
||||
from onyx.server.kg.api import admin_router as kg_admin_router
|
||||
from onyx.server.manage.administrative import router as admin_router
|
||||
from onyx.server.manage.code_interpreter.api import (
|
||||
admin_router as code_interpreter_admin_router,
|
||||
)
|
||||
from onyx.server.manage.discord_bot.api import router as discord_bot_router
|
||||
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from onyx.server.manage.embedding.api import basic_router as embedding_router
|
||||
@@ -424,9 +421,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, kg_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, code_interpreter_admin_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, image_generation_admin_router
|
||||
)
|
||||
|
||||
@@ -1,68 +1,14 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
|
||||
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
|
||||
if "[[" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _CITATION_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
normalized_message = _normalize_citation_link_destinations(message)
|
||||
result = md(normalized_message)
|
||||
result = md(message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
|
||||
@@ -762,43 +762,6 @@ def download_webapp(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/download-directory/{path:path}")
|
||||
def download_directory(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""
|
||||
Download a directory as a zip file.
|
||||
|
||||
Returns the specified directory as a zip archive.
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
result = session_manager.download_directory(session_id, user_id, path)
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if "path traversal" in error_message.lower():
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Directory not found")
|
||||
|
||||
zip_bytes, filename = result
|
||||
|
||||
return Response(
|
||||
content=zip_bytes,
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/upload", response_model=UploadResponse)
|
||||
def upload_file_endpoint(
|
||||
session_id: UUID,
|
||||
|
||||
@@ -107,23 +107,27 @@ def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int,
|
||||
)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if (
|
||||
cc_pair.connector.source == DocumentSource.CRAFT_FILE
|
||||
and cc_pair.creator_id == user.id
|
||||
):
|
||||
if cc_pair.connector.source == DocumentSource.CRAFT_FILE:
|
||||
return cc_pair.connector.id, cc_pair.credential.id
|
||||
|
||||
# No cc_pair for this user — find or create the shared CRAFT_FILE connector
|
||||
# Check for orphaned connector (created but cc_pair creation failed previously)
|
||||
existing_connectors = fetch_connectors(
|
||||
db_session, sources=[DocumentSource.CRAFT_FILE]
|
||||
)
|
||||
connector_id: int | None = None
|
||||
orphaned_connector = None
|
||||
for conn in existing_connectors:
|
||||
if conn.name == USER_LIBRARY_CONNECTOR_NAME:
|
||||
connector_id = conn.id
|
||||
if conn.name != USER_LIBRARY_CONNECTOR_NAME:
|
||||
continue
|
||||
if not conn.credentials:
|
||||
orphaned_connector = conn
|
||||
break
|
||||
|
||||
if connector_id is None:
|
||||
if orphaned_connector:
|
||||
connector_id = orphaned_connector.id
|
||||
logger.info(
|
||||
f"Found orphaned User Library connector {connector_id}, completing setup"
|
||||
)
|
||||
else:
|
||||
connector_data = ConnectorBase(
|
||||
name=USER_LIBRARY_CONNECTOR_NAME,
|
||||
source=DocumentSource.CRAFT_FILE,
|
||||
|
||||
Binary file not shown.
@@ -1,19 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate AGENTS.md by scanning the files directory and populating the template.
|
||||
|
||||
This script runs during session setup, AFTER files have been synced from S3
|
||||
and the files symlink has been created. It reads an existing AGENTS.md (which
|
||||
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
|
||||
placeholder by scanning the knowledge source directory, and writes it back.
|
||||
This script runs at container startup, AFTER the init container has synced files
|
||||
from S3. It scans the /workspace/files directory to discover what knowledge sources
|
||||
are available and generates appropriate documentation.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
|
||||
Arguments:
|
||||
agents_md_path: Path to the AGENTS.md file to update in place
|
||||
files_path: Path to the files directory to scan for knowledge sources
|
||||
Environment variables:
|
||||
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -193,39 +189,49 @@ def build_knowledge_sources_section(files_path: Path) -> str:
|
||||
def main() -> None:
|
||||
"""Main entry point for container startup script.
|
||||
|
||||
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
|
||||
placeholder by scanning the files directory, and writes it back.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
Is called by the container startup script to scan /workspace/files and populate
|
||||
the knowledge sources section.
|
||||
"""
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
# Read template from environment variable
|
||||
template = os.environ.get("AGENT_INSTRUCTIONS", "")
|
||||
if not template:
|
||||
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
|
||||
template = "# Agent Instructions\n\nNo instructions provided."
|
||||
|
||||
agents_md_path = Path(sys.argv[1])
|
||||
files_path = Path(sys.argv[2])
|
||||
# Scan files directory - check /workspace/files first, then /workspace/demo_data
|
||||
files_path = Path("/workspace/files")
|
||||
demo_data_path = Path("/workspace/demo_data")
|
||||
|
||||
if not agents_md_path.exists():
|
||||
print(f"Error: {agents_md_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
# Use demo_data if files doesn't exist or is empty
|
||||
if not files_path.exists() or not any(files_path.iterdir()):
|
||||
if demo_data_path.exists():
|
||||
files_path = demo_data_path
|
||||
|
||||
template = agents_md_path.read_text()
|
||||
knowledge_sources_section = build_knowledge_sources_section(files_path)
|
||||
|
||||
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
|
||||
resolved_files_path = files_path.resolve()
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
|
||||
|
||||
# Replace placeholder and write back
|
||||
content = template.replace(
|
||||
# Replace placeholders
|
||||
content = template
|
||||
content = content.replace(
|
||||
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
|
||||
)
|
||||
agents_md_path.write_text(content)
|
||||
print(f"Populated knowledge sources in {agents_md_path}")
|
||||
|
||||
# Write AGENTS.md
|
||||
output_path = Path("/workspace/AGENTS.md")
|
||||
output_path.write_text(content)
|
||||
|
||||
# Log result
|
||||
source_count = 0
|
||||
if files_path.exists():
|
||||
source_count = len(
|
||||
[
|
||||
d
|
||||
for d in files_path.iterdir()
|
||||
if d.is_dir() and not d.name.startswith(".")
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1352,9 +1352,6 @@ fi
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
@@ -1783,9 +1780,6 @@ ln -sf {symlink_target} {session_path}/files
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
|
||||
@@ -68,7 +68,6 @@ from onyx.server.features.build.db.sandbox import create_sandbox__no_commit
|
||||
from onyx.server.features.build.db.sandbox import get_running_sandbox_count_by_tenant
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_session_id
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.db.sandbox import get_snapshots_for_session
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
@@ -647,30 +646,16 @@ class SessionManager:
|
||||
|
||||
if sandbox and sandbox.status.is_active():
|
||||
# Quick health check to verify sandbox is actually responsive
|
||||
# AND verify the session workspace still exists on disk
|
||||
# (it may have been wiped if the sandbox was re-provisioned)
|
||||
is_healthy = self._sandbox_manager.health_check(sandbox.id, timeout=5.0)
|
||||
workspace_exists = (
|
||||
is_healthy
|
||||
and self._sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, existing.id
|
||||
)
|
||||
)
|
||||
if is_healthy and workspace_exists:
|
||||
if self._sandbox_manager.health_check(sandbox.id, timeout=5.0):
|
||||
logger.info(
|
||||
f"Returning existing empty session {existing.id} for user {user_id}"
|
||||
)
|
||||
return existing
|
||||
elif not is_healthy:
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} has unhealthy sandbox {sandbox.id}. "
|
||||
f"Deleting and creating fresh session."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} workspace missing in sandbox "
|
||||
f"{sandbox.id}. Deleting and creating fresh session."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} has no active sandbox "
|
||||
@@ -1050,23 +1035,6 @@ class SessionManager:
|
||||
# workspace cleanup fails (e.g., if pod is already terminated)
|
||||
logger.warning(f"Failed to cleanup session workspace {session_id}: {e}")
|
||||
|
||||
# Delete snapshot files from S3 before removing DB records
|
||||
snapshots = get_snapshots_for_session(self._db_session, session_id)
|
||||
if snapshots:
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.features.build.sandbox.manager.snapshot_manager import (
|
||||
SnapshotManager,
|
||||
)
|
||||
|
||||
snapshot_manager = SnapshotManager(get_default_file_store())
|
||||
for snapshot in snapshots:
|
||||
try:
|
||||
snapshot_manager.delete_snapshot(snapshot.storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to delete snapshot file {snapshot.storage_path}: {e}"
|
||||
)
|
||||
|
||||
# Delete session (uses flush, caller commits)
|
||||
return delete_build_session__no_commit(session_id, user_id, self._db_session)
|
||||
|
||||
@@ -1935,94 +1903,6 @@ class SessionManager:
|
||||
|
||||
return zip_buffer.getvalue(), filename
|
||||
|
||||
def download_directory(
|
||||
self,
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
Create a zip file of an arbitrary directory in the session workspace.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID
|
||||
user_id: The user ID to verify ownership
|
||||
path: Relative path to the directory (within session workspace)
|
||||
|
||||
Returns:
|
||||
Tuple of (zip_bytes, filename) or None if session not found
|
||||
|
||||
Raises:
|
||||
ValueError: If path traversal attempted or path is not a directory
|
||||
"""
|
||||
# Verify session ownership
|
||||
session = get_build_session(session_id, user_id, self._db_session)
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
|
||||
if sandbox is None:
|
||||
return None
|
||||
|
||||
# Check if directory exists
|
||||
try:
|
||||
self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Recursively collect all files
|
||||
def collect_files(dir_path: str) -> list[tuple[str, str]]:
|
||||
"""Collect all files recursively, returning (full_path, arcname) tuples."""
|
||||
files: list[tuple[str, str]] = []
|
||||
try:
|
||||
entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=dir_path,
|
||||
)
|
||||
for entry in entries:
|
||||
if entry.is_directory:
|
||||
files.extend(collect_files(entry.path))
|
||||
else:
|
||||
# arcname is relative to the target directory
|
||||
prefix_len = len(path) + 1 # +1 for trailing slash
|
||||
arcname = entry.path[prefix_len:]
|
||||
files.append((entry.path, arcname))
|
||||
except ValueError:
|
||||
pass
|
||||
return files
|
||||
|
||||
file_list = collect_files(path)
|
||||
|
||||
# Create zip file in memory
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
|
||||
for full_path, arcname in file_list:
|
||||
try:
|
||||
content = self._sandbox_manager.read_file(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=full_path,
|
||||
)
|
||||
zip_file.writestr(arcname, content)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# Use the directory name for the zip filename
|
||||
dir_name = Path(path).name
|
||||
safe_name = "".join(
|
||||
c if c.isalnum() or c in ("-", "_", ".") else "_" for c in dir_name
|
||||
)
|
||||
filename = f"{safe_name}.zip"
|
||||
|
||||
return zip_buffer.getvalue(), filename
|
||||
|
||||
# =========================================================================
|
||||
# File System Operations
|
||||
# =========================================================================
|
||||
@@ -2057,18 +1937,11 @@ class SessionManager:
|
||||
return None
|
||||
|
||||
# Use sandbox manager to list directory (works for both local and K8s)
|
||||
# If the directory doesn't exist (e.g., session workspace not yet loaded),
|
||||
# return an empty listing rather than erroring out.
|
||||
try:
|
||||
raw_entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "path traversal" in str(e).lower():
|
||||
raise
|
||||
return DirectoryListing(path=path, entries=[])
|
||||
raw_entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
|
||||
# Filter hidden files and directories
|
||||
entries: list[FileSystemEntry] = [
|
||||
|
||||
@@ -12,18 +12,11 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -34,7 +27,6 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -55,33 +47,6 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
f"Skipping immediate project sync for user_file_id={user_file_id} due to "
|
||||
f"queue depth {queue_depth}>{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}. "
|
||||
"It will be picked up by beat later."
|
||||
)
|
||||
return
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
enqueued = enqueue_user_file_project_sync_task(
|
||||
celery_app=client_app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
if not enqueued:
|
||||
logger.info(
|
||||
f"Skipped duplicate project sync enqueue for user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Triggered project sync for user_file_id={user_file_id}")
|
||||
|
||||
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User = Depends(current_user),
|
||||
@@ -224,7 +189,15 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -268,7 +241,15 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/code-interpreter")
|
||||
|
||||
|
||||
@admin_router.get("/health")
|
||||
def get_code_interpreter_health(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> CodeInterpreterServerHealth:
|
||||
try:
|
||||
client = CodeInterpreterClient()
|
||||
return CodeInterpreterServerHealth(healthy=client.health())
|
||||
except ValueError:
|
||||
return CodeInterpreterServerHealth(healthy=False)
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def get_code_interpreter(
|
||||
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
|
||||
) -> CodeInterpreterServer:
|
||||
ci_server = fetch_code_interpreter_server(db_session)
|
||||
return CodeInterpreterServer(enabled=ci_server.server_enabled)
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
def update_code_interpreter(
|
||||
update: CodeInterpreterServer,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_code_interpreter_server_enabled(
|
||||
db_session=db_session,
|
||||
enabled=update.enabled,
|
||||
)
|
||||
@@ -1,9 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterServer(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class CodeInterpreterServerHealth(BaseModel):
|
||||
healthy: bool
|
||||
@@ -97,6 +97,7 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -237,12 +233,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.id:
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -272,7 +268,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.model,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -307,7 +303,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
) -> list[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -332,25 +328,7 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -363,29 +341,21 @@ def put_llm_provider(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderView:
|
||||
# NOTE: Name updating functionality currently not supported. There are many places that still
|
||||
# rely on immutable names, so this will be a larger change
|
||||
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} already exists",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} does not exist",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -423,6 +393,22 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -452,8 +438,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -483,29 +469,28 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
def set_provider_as_default(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/default-vision")
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
vision_model=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@@ -531,7 +516,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -560,18 +545,7 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return vision_provider_response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -581,7 +555,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -618,25 +592,7 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
return accessible_providers
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -679,7 +635,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -726,63 +682,7 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
# Get the default model and vision model for the persona
|
||||
# NOTE: This should be ported over to use id as it is blocking on name mutability
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_text_model.llm_provider.id,
|
||||
model_name=default_text_model.name,
|
||||
)
|
||||
if default_text_model
|
||||
else None
|
||||
)
|
||||
default_vision: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_vision_model.llm_provider.id,
|
||||
model_name=default_vision_model.name,
|
||||
)
|
||||
if default_vision_model
|
||||
else None
|
||||
)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider:
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -23,8 +21,6 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
@@ -56,18 +52,19 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -83,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,12 +99,22 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -118,17 +128,18 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -144,6 +155,8 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -165,6 +178,14 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -177,6 +198,10 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -203,8 +228,7 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=LLMModelFlowType.VISION
|
||||
in model_configuration_model.llm_model_flow_types,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
display_name=model_configuration_model.display_name,
|
||||
)
|
||||
|
||||
@@ -397,27 +421,3 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> "LLMProviderResponse[T]":
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -35,18 +35,6 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class EmailInviteStatus(str, Enum):
|
||||
SENT = "SENT"
|
||||
NOT_CONFIGURED = "NOT_CONFIGURED"
|
||||
SEND_FAILED = "SEND_FAILED"
|
||||
DISABLED = "DISABLED"
|
||||
|
||||
|
||||
class BulkInviteResponse(BaseModel):
|
||||
invited_count: int
|
||||
email_invite_status: EmailInviteStatus
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
backend_version: str
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
@@ -79,10 +78,8 @@ from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import BulkInviteResponse
|
||||
from onyx.server.manage.models import ChatBackgroundRequest
|
||||
from onyx.server.manage.models import DefaultAppModeRequest
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import PersonalizationUpdateRequest
|
||||
from onyx.server.manage.models import TenantInfo
|
||||
@@ -371,7 +368,7 @@ def bulk_invite_users(
|
||||
emails: list[str] = Body(..., embed=True),
|
||||
current_user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BulkInviteResponse:
|
||||
) -> int:
|
||||
"""emails are string validated. If any email fails validation, no emails are
|
||||
invited and an exception is raised."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -430,41 +427,34 @@ def bulk_invite_users(
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
# send out email invitations only to new users (not already invited or existing)
|
||||
if not ENABLE_EMAIL_INVITES:
|
||||
email_invite_status = EmailInviteStatus.DISABLED
|
||||
elif not EMAIL_CONFIGURED:
|
||||
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
|
||||
else:
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in emails_needing_seats:
|
||||
send_user_email_invite(email, current_user, AUTH_TYPE)
|
||||
email_invite_status = EmailInviteStatus.SENT
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
email_invite_status = EmailInviteStatus.SEND_FAILED
|
||||
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
if not MULTI_TENANT or DEV_MODE:
|
||||
return number_of_invited_users
|
||||
|
||||
return BulkInviteResponse(
|
||||
invited_count=number_of_invited_users,
|
||||
email_invite_status=email_invite_status,
|
||||
)
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
|
||||
return number_of_invited_users
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
|
||||
|
||||
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -587,7 +587,6 @@ def handle_send_chat_message(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
additional_context=chat_message_req.additional_context,
|
||||
external_state_container=state_container,
|
||||
)
|
||||
result = gather_stream_full(packets, state_container)
|
||||
@@ -610,7 +609,6 @@ def handle_send_chat_message(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
additional_context=chat_message_req.additional_context,
|
||||
external_state_container=state_container,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
@@ -125,11 +125,6 @@ class SendMessageRequest(BaseModel):
|
||||
# - No CitationInfo packets are emitted during streaming
|
||||
include_citations: bool = True
|
||||
|
||||
# Additional context injected into the LLM call but NOT stored in the DB
|
||||
# (not shown in chat history). Used e.g. by the Chrome extension to pass
|
||||
# the current tab URL when "Read this tab" is enabled.
|
||||
additional_context: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chat_session_id_or_info(self) -> "SendMessageRequest":
|
||||
# If neither is provided, default to creating a new chat session using the
|
||||
|
||||
@@ -245,11 +245,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
if (
|
||||
GEN_AI_API_KEY
|
||||
and fetch_default_llm_model(db_session) is None
|
||||
and not INTEGRATION_TESTS_MODE
|
||||
):
|
||||
if GEN_AI_API_KEY and fetch_default_llm_model(db_session) is None:
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
@@ -261,6 +257,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -272,9 +269,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
@@ -39,39 +36,6 @@ class ExecuteResponse(BaseModel):
|
||||
files: list[WorkspaceFile]
|
||||
|
||||
|
||||
class StreamOutputEvent(BaseModel):
|
||||
"""SSE 'output' event: a chunk of stdout or stderr"""
|
||||
|
||||
stream: Literal["stdout", "stderr"]
|
||||
data: str
|
||||
|
||||
|
||||
class StreamResultEvent(BaseModel):
|
||||
"""SSE 'result' event: final execution result"""
|
||||
|
||||
exit_code: int | None
|
||||
timed_out: bool
|
||||
duration_ms: int
|
||||
files: list[WorkspaceFile]
|
||||
|
||||
|
||||
class StreamErrorEvent(BaseModel):
|
||||
"""SSE 'error' event: execution-level error"""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
StreamEvent = Union[StreamOutputEvent, StreamResultEvent, StreamErrorEvent]
|
||||
|
||||
_SSE_EVENT_MAP: dict[
|
||||
str, type[StreamOutputEvent | StreamResultEvent | StreamErrorEvent]
|
||||
] = {
|
||||
"output": StreamOutputEvent,
|
||||
"result": StreamResultEvent,
|
||||
"error": StreamErrorEvent,
|
||||
}
|
||||
|
||||
|
||||
class CodeInterpreterClient:
|
||||
"""Client for Code Interpreter service"""
|
||||
|
||||
@@ -81,34 +45,6 @@ class CodeInterpreterClient:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None,
|
||||
timeout_ms: int,
|
||||
files: list[FileInput] | None,
|
||||
) -> dict:
|
||||
payload: dict = {
|
||||
"code": code,
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
if stdin is not None:
|
||||
payload["stdin"] = stdin
|
||||
if files:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
@@ -116,110 +52,25 @@ class CodeInterpreterClient:
|
||||
timeout_ms: int = 30000,
|
||||
files: list[FileInput] | None = None,
|
||||
) -> ExecuteResponse:
|
||||
"""Execute Python code (batch)"""
|
||||
"""Execute Python code"""
|
||||
url = f"{self.base_url}/v1/execute"
|
||||
payload = self._build_payload(code, stdin, timeout_ms, files)
|
||||
|
||||
payload = {
|
||||
"code": code,
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
|
||||
if stdin is not None:
|
||||
payload["stdin"] = stdin
|
||||
|
||||
if files:
|
||||
payload["files"] = files
|
||||
|
||||
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
|
||||
response.raise_for_status()
|
||||
|
||||
return ExecuteResponse(**response.json())
|
||||
|
||||
def execute_streaming(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
files: list[FileInput] | None = None,
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Execute Python code with streaming SSE output.
|
||||
|
||||
Yields StreamEvent objects (StreamOutputEvent, StreamResultEvent,
|
||||
StreamErrorEvent) as execution progresses. Falls back to batch
|
||||
execution if the streaming endpoint is not available (older
|
||||
code-interpreter versions).
|
||||
"""
|
||||
url = f"{self.base_url}/v1/execute/stream"
|
||||
payload = self._build_payload(code, stdin, timeout_ms, files)
|
||||
|
||||
response = self.session.post(
|
||||
url,
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=timeout_ms / 1000 + 10,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
logger.info(
|
||||
"Streaming endpoint not available, " "falling back to batch execution"
|
||||
)
|
||||
response.close()
|
||||
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
|
||||
def _parse_sse(
|
||||
self, response: requests.Response
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Parse SSE streaming response into StreamEvent objects.
|
||||
|
||||
Expected format per event:
|
||||
event: <type>
|
||||
data: <json>
|
||||
<blank line>
|
||||
"""
|
||||
event_type: str | None = None
|
||||
data_lines: list[str] = []
|
||||
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line is None:
|
||||
continue
|
||||
|
||||
if line == "":
|
||||
# Blank line marks end of an SSE event
|
||||
if event_type is not None and data_lines:
|
||||
data = "\n".join(data_lines)
|
||||
model_cls = _SSE_EVENT_MAP.get(event_type)
|
||||
if model_cls is not None:
|
||||
yield model_cls(**json.loads(data))
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event type: {event_type}")
|
||||
event_type = None
|
||||
data_lines = []
|
||||
elif line.startswith("event:"):
|
||||
event_type = line[len("event:") :].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_lines.append(line[len("data:") :].strip())
|
||||
|
||||
if event_type is not None or data_lines:
|
||||
logger.warning(
|
||||
f"SSE stream ended with incomplete event: "
|
||||
f"event_type={event_type}, data_lines={data_lines}"
|
||||
)
|
||||
|
||||
def _batch_as_stream(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None,
|
||||
timeout_ms: int,
|
||||
files: list[FileInput] | None,
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Execute via batch endpoint and yield results as stream events."""
|
||||
result = self.execute(code, stdin, timeout_ms, files)
|
||||
|
||||
if result.stdout:
|
||||
yield StreamOutputEvent(stream="stdout", data=result.stdout)
|
||||
if result.stderr:
|
||||
yield StreamOutputEvent(stream="stderr", data=result.stderr)
|
||||
yield StreamResultEvent(
|
||||
exit_code=result.exit_code,
|
||||
timed_out=result.timed_out,
|
||||
duration_ms=result.duration_ms,
|
||||
files=result.files,
|
||||
)
|
||||
|
||||
def upload_file(self, file_content: bytes, filename: str) -> str:
|
||||
"""Upload file to Code Interpreter and return file_id"""
|
||||
url = f"{self.base_url}/v1/files"
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -29,15 +28,6 @@ from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamErrorEvent,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamOutputEvent,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamResultEvent,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -104,10 +94,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
is_available = bool(CODE_INTERPRETER_BASE_URL)
|
||||
return is_available
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -193,50 +181,19 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
|
||||
for event in client.execute_streaming(
|
||||
# Execute code with timeout
|
||||
response = client.execute(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=event.data if event.stream == "stdout" else "",
|
||||
stderr=event.data if event.stream == "stderr" else "",
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
)
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
# Handle generated files
|
||||
@@ -245,7 +202,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
for workspace_file in response.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
@@ -301,23 +258,26 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
)
|
||||
# Emit delta with stdout/stderr and generated files
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
file_ids=generated_file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
exit_code=response.exit_code,
|
||||
timed_out=response.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if result_event.exit_code == 0 else truncated_stderr,
|
||||
error=None if response.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
# Serialize result for LLM
|
||||
|
||||
@@ -6,8 +6,6 @@ aioboto3==15.1.0
|
||||
# via onyx
|
||||
aiobotocore==2.24.0
|
||||
# via aioboto3
|
||||
aiofile==3.9.0
|
||||
# via py-key-value-aio
|
||||
aiofiles==25.1.0
|
||||
# via
|
||||
# aioboto3
|
||||
@@ -42,10 +40,8 @@ anyio==4.11.0
|
||||
# httpx
|
||||
# mcp
|
||||
# openai
|
||||
# py-key-value-aio
|
||||
# sse-starlette
|
||||
# starlette
|
||||
# watchfiles
|
||||
argon2-cffi==23.1.0
|
||||
# via pwdlib
|
||||
argon2-cffi-bindings==25.1.0
|
||||
@@ -78,7 +74,9 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12'
|
||||
bcrypt==4.3.0
|
||||
# via pwdlib
|
||||
beartype==0.22.6
|
||||
# via py-key-value-aio
|
||||
# via
|
||||
# py-key-value-aio
|
||||
# py-key-value-shared
|
||||
beautifulsoup4==4.12.3
|
||||
# via
|
||||
# atlassian-python-api
|
||||
@@ -112,8 +110,6 @@ cachetools==6.2.2
|
||||
# via
|
||||
# google-auth
|
||||
# py-key-value-aio
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
@@ -174,6 +170,7 @@ cloudpickle==3.1.2
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# pydocket
|
||||
cobble==0.1.4
|
||||
# via mammoth
|
||||
cohere==5.6.1
|
||||
@@ -221,6 +218,8 @@ deprecated==1.3.1
|
||||
# pygithub
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
diskcache==5.6.3
|
||||
# via py-key-value-aio
|
||||
distributed==2026.1.1
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
@@ -257,6 +256,8 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fakeredis==2.33.0
|
||||
# via pydocket
|
||||
fastapi==0.128.0
|
||||
# via
|
||||
# fastapi-limiter
|
||||
@@ -272,7 +273,7 @@ fastapi-users-db-sqlalchemy==7.0.0
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastmcp==3.0.2
|
||||
fastmcp==2.14.2
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
@@ -477,9 +478,7 @@ jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
# via
|
||||
# fastmcp
|
||||
# onyx
|
||||
# via onyx
|
||||
jsonschema==4.25.1
|
||||
# via
|
||||
# litellm
|
||||
@@ -514,6 +513,8 @@ locket==1.0.0
|
||||
# via
|
||||
# distributed
|
||||
# partd
|
||||
lupa==2.6
|
||||
# via fakeredis
|
||||
lxml==5.3.0
|
||||
# via
|
||||
# htmldate
|
||||
@@ -555,7 +556,7 @@ marshmallow==3.26.2
|
||||
# via dataclasses-json
|
||||
matrix-client==0.3.2
|
||||
# via zulip
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via
|
||||
# claude-agent-sdk
|
||||
# fastmcp
|
||||
@@ -612,7 +613,7 @@ oauthlib==3.2.2
|
||||
# kubernetes
|
||||
# onyx
|
||||
# requests-oauthlib
|
||||
office365-rest-python-client==2.6.2
|
||||
office365-rest-python-client==2.5.9
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via
|
||||
@@ -641,16 +642,22 @@ opensearch-py==3.0.0
|
||||
opentelemetry-api==1.39.1
|
||||
# via
|
||||
# ddtrace
|
||||
# fastmcp
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-exporter-prometheus
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# pydocket
|
||||
opentelemetry-exporter-otlp-proto-common==1.39.1
|
||||
# via opentelemetry-exporter-otlp-proto-http
|
||||
opentelemetry-exporter-otlp-proto-http==1.39.1
|
||||
# via langfuse
|
||||
opentelemetry-exporter-prometheus==0.60b1
|
||||
# via pydocket
|
||||
opentelemetry-instrumentation==0.60b1
|
||||
# via pydocket
|
||||
opentelemetry-proto==1.39.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -661,15 +668,17 @@ opentelemetry-sdk==1.39.1
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-exporter-prometheus
|
||||
opentelemetry-semantic-conventions==0.60b1
|
||||
# via opentelemetry-sdk
|
||||
# via
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
|
||||
# via langsmith
|
||||
packaging==24.2
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
@@ -680,6 +689,7 @@ packaging==24.2
|
||||
# langsmith
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# opentelemetry-instrumentation
|
||||
# pytest
|
||||
# pywikibot
|
||||
pandas==2.3.3
|
||||
@@ -692,6 +702,8 @@ passlib==1.7.4
|
||||
# via onyx
|
||||
pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pathvalidate==3.3.1
|
||||
# via py-key-value-aio
|
||||
pdfminer-six==20251107
|
||||
# via markitdown
|
||||
pillow==12.1.1
|
||||
@@ -711,7 +723,9 @@ ply==3.11
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
# opentelemetry-exporter-prometheus
|
||||
# prometheus-fastapi-instrumentator
|
||||
# pydocket
|
||||
prometheus-fastapi-instrumentator==7.1.0
|
||||
# via onyx
|
||||
prompt-toolkit==3.0.52
|
||||
@@ -750,8 +764,12 @@ pwdlib==0.3.0
|
||||
# via fastapi-users
|
||||
py==1.11.0
|
||||
# via retry
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
py-key-value-aio==0.3.0
|
||||
# via
|
||||
# fastmcp
|
||||
# pydocket
|
||||
py-key-value-shared==0.3.0
|
||||
# via py-key-value-aio
|
||||
pyairtable==3.0.1
|
||||
# via onyx
|
||||
pyasn1==0.6.2
|
||||
@@ -788,6 +806,8 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pydocket==0.16.3
|
||||
# via fastmcp
|
||||
pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
@@ -859,6 +879,8 @@ python-http-client==3.3.7
|
||||
# via sendgrid
|
||||
python-iso639==2025.11.16
|
||||
# via unstructured
|
||||
python-json-logger==4.0.0
|
||||
# via pydocket
|
||||
python-magic==0.4.27
|
||||
# via unstructured
|
||||
python-multipart==0.0.22
|
||||
@@ -896,7 +918,6 @@ pyyaml==6.0.3
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# huggingface-hub
|
||||
# jsonschema-path
|
||||
# kubernetes
|
||||
@@ -907,8 +928,11 @@ rapidfuzz==3.13.0
|
||||
# unstructured
|
||||
redis==5.0.8
|
||||
# via
|
||||
# fakeredis
|
||||
# fastapi-limiter
|
||||
# onyx
|
||||
# py-key-value-aio
|
||||
# pydocket
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -983,6 +1007,7 @@ rich==14.2.0
|
||||
# via
|
||||
# cyclopts
|
||||
# fastmcp
|
||||
# pydocket
|
||||
# rich-rst
|
||||
# typer
|
||||
rich-rst==1.3.2
|
||||
@@ -1031,7 +1056,9 @@ sniffio==1.3.1
|
||||
# anyio
|
||||
# openai
|
||||
sortedcontainers==2.4.0
|
||||
# via distributed
|
||||
# via
|
||||
# distributed
|
||||
# fakeredis
|
||||
soupsieve==2.8
|
||||
# via beautifulsoup4
|
||||
sqlalchemy==2.0.15
|
||||
@@ -1097,7 +1124,9 @@ tqdm==4.67.1
|
||||
trafilatura==1.12.2
|
||||
# via onyx
|
||||
typer==0.20.0
|
||||
# via mcp
|
||||
# via
|
||||
# mcp
|
||||
# pydocket
|
||||
types-awscrt==0.28.4
|
||||
# via botocore-stubs
|
||||
types-openpyxl==3.0.4.7
|
||||
@@ -1133,10 +1162,11 @@ typing-extensions==4.15.0
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# py-key-value-aio
|
||||
# py-key-value-shared
|
||||
# pyairtable
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydocket
|
||||
# pyee
|
||||
# pygithub
|
||||
# python-docx
|
||||
@@ -1204,8 +1234,6 @@ vine==5.1.0
|
||||
# kombu
|
||||
voyageai==0.2.3
|
||||
# via onyx
|
||||
watchfiles==1.1.1
|
||||
# via fastmcp
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
webencodings==0.5.1
|
||||
@@ -1226,6 +1254,7 @@ wrapt==1.17.3
|
||||
# deprecated
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-instrumentation
|
||||
# unstructured
|
||||
xlrd==2.0.2
|
||||
# via markitdown
|
||||
|
||||
@@ -288,7 +288,7 @@ matplotlib-inline==0.2.1
|
||||
# via
|
||||
# ipykernel
|
||||
# ipython
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
multidict==6.7.0
|
||||
# via
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.1
|
||||
onyx-devtools==0.6.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -211,7 +211,7 @@ litellm==1.81.6
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
monotonic==1.6
|
||||
# via posthog
|
||||
|
||||
@@ -246,7 +246,7 @@ litellm==1.81.6
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
|
||||
@@ -95,7 +95,6 @@ def generate_dummy_chunk(
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
user_project=[],
|
||||
personas=[],
|
||||
access=DocumentAccess.build(
|
||||
user_emails=user_emails,
|
||||
user_groups=user_groups,
|
||||
|
||||
@@ -3,8 +3,8 @@ set -e
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -55,10 +55,6 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Start the Code Interpreter container
|
||||
echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -47,15 +46,11 @@ def mock_current_admin_user() -> MagicMock:
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan.
|
||||
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
|
||||
# CollectorRegistry" errors when multiple tests each create a new app
|
||||
# (prometheus registers metrics globally and rejects duplicate names).
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan
|
||||
get_app = fetch_versioned_implementation(
|
||||
module="onyx.main", attribute="get_application"
|
||||
)
|
||||
with patch("onyx.main.setup_prometheus_metrics"):
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
|
||||
# Override the database session dependency with a mock
|
||||
# (these tests don't actually need DB access)
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -26,6 +26,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
@@ -43,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -52,6 +53,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -45,14 +44,15 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> LLMProviderView:
|
||||
) -> None:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
return upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -107,7 +107,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -152,7 +157,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -184,9 +194,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -194,13 +202,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -247,7 +259,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -280,7 +297,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
provider = upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -288,6 +305,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -298,14 +321,18 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -346,7 +373,12 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model=model_name,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -410,6 +442,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -419,7 +452,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -439,6 +472,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -465,11 +499,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -478,9 +512,6 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -493,7 +524,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -565,6 +596,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -573,7 +605,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -20,7 +20,6 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
@@ -50,6 +49,7 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -91,14 +91,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -125,16 +125,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -159,16 +159,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -192,14 +190,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -225,16 +223,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -263,14 +259,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -301,6 +297,7 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -325,7 +322,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -333,11 +330,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -365,15 +362,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -402,7 +399,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -410,13 +407,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -441,17 +438,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -477,7 +474,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -485,10 +482,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -535,7 +532,12 @@ def test_upload_with_custom_config_then_change(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -544,10 +546,11 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -566,10 +569,14 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -579,9 +586,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -609,9 +616,7 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=name
|
||||
)
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
@@ -637,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -659,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -713,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -735,10 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -136,6 +135,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -163,8 +163,13 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -233,6 +238,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -304,13 +310,14 @@ class TestAutoModeSyncFeature:
|
||||
|
||||
try:
|
||||
# Step 1: Upload provider WITHOUT auto mode, with initial models
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -337,12 +344,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -353,8 +360,8 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=provider_name
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
@@ -381,6 +388,9 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -422,12 +432,8 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -529,6 +535,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -542,7 +549,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -556,6 +563,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -576,7 +584,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
|
||||
@@ -64,6 +64,7 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -434,6 +434,7 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -447,7 +448,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
|
||||
@@ -990,27 +990,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
self._respond_json(
|
||||
200, {"file_id": f"mock-ci-file-{self.server._file_counter}"}
|
||||
)
|
||||
elif self.path == "/v1/execute/stream":
|
||||
if self.server.streaming_enabled:
|
||||
self._respond_sse(
|
||||
[
|
||||
(
|
||||
"output",
|
||||
{"stream": "stdout", "data": "mock output\n"},
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"exit_code": 0,
|
||||
"timed_out": False,
|
||||
"duration_ms": 50,
|
||||
"files": [],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
elif self.path == "/v1/execute":
|
||||
self._respond_json(
|
||||
200,
|
||||
@@ -1048,17 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
self.wfile.write(payload)
|
||||
|
||||
def _respond_sse(self, events: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
frames = []
|
||||
for event_type, data in events:
|
||||
frames.append(f"event: {event_type}\ndata: {json.dumps(data)}\n\n")
|
||||
payload = "".join(frames).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Content-Length", str(len(payload)))
|
||||
self.end_headers()
|
||||
self.wfile.write(payload)
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
|
||||
pass
|
||||
|
||||
@@ -1070,7 +1038,6 @@ class MockCodeInterpreterServer(HTTPServer):
|
||||
super().__init__(("localhost", 0), _MockCIHandler)
|
||||
self.captured_requests: list[CapturedRequest] = []
|
||||
self._file_counter = 0
|
||||
self.streaming_enabled: bool = True
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@@ -1201,19 +1168,17 @@ def test_code_interpreter_receives_chat_files(
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
|
||||
# Verify: file uploaded, code executed via streaming, staged file cleaned up
|
||||
# Verify: file uploaded, code executed, staged file cleaned up
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
delete_requests = mock_ci_server.get_requests(method="DELETE")
|
||||
assert len(delete_requests) == 1
|
||||
assert delete_requests[0].path.startswith("/v1/files/")
|
||||
|
||||
execute_body = mock_ci_server.get_requests(
|
||||
method="POST", path="/v1/execute/stream"
|
||||
)[0].json_body()
|
||||
execute_body = mock_ci_server.get_requests(method="POST", path="/v1/execute")[
|
||||
0
|
||||
].json_body()
|
||||
assert execute_body["code"] == code
|
||||
assert len(execute_body["files"]) == 1
|
||||
assert execute_body["files"][0]["path"] == "data.csv"
|
||||
@@ -1319,9 +1284,7 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
# The response contains `packets` — a list of packet-lists, one per
|
||||
# assistant message. We should have exactly one assistant message.
|
||||
@@ -1350,76 +1313,3 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
delta_obj = delta_packets[0].obj
|
||||
assert isinstance(delta_obj, PythonToolDelta)
|
||||
assert "mock output" in delta_obj.stdout
|
||||
|
||||
|
||||
def test_code_interpreter_streaming_fallback_to_batch(
|
||||
db_session: Session,
|
||||
mock_ci_server: MockCodeInterpreterServer,
|
||||
_attach_python_tool_to_default_persona: None,
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the streaming endpoint is not available (older code-interpreter),
|
||||
execute_streaming should fall back to the batch /v1/execute endpoint."""
|
||||
mock_ci_server.captured_requests.clear()
|
||||
mock_ci_server._file_counter = 0
|
||||
mock_ci_server.streaming_enabled = False
|
||||
mock_url = mock_ci_server.url
|
||||
|
||||
user = create_test_user(db_session, "ci_fallback_test")
|
||||
chat_session = create_chat_session(db_session=db_session, user=user)
|
||||
|
||||
code = 'print("fallback test")'
|
||||
msg_req = SendMessageRequest(
|
||||
message="Print fallback test",
|
||||
chat_session_id=chat_session.id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
|
||||
with (
|
||||
use_mock_llm() as mock_llm,
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
):
|
||||
mock_llm.add_response(
|
||||
LLMToolCallResponse(
|
||||
tool_name="python",
|
||||
tool_call_id="call_fallback",
|
||||
tool_call_argument_tokens=[json.dumps({"code": code})],
|
||||
)
|
||||
)
|
||||
mock_llm.forward_till_end()
|
||||
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
|
||||
try:
|
||||
packets = list(
|
||||
handle_stream_message_objects(
|
||||
new_msg_req=msg_req, user=user, db_session=db_session
|
||||
)
|
||||
)
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
mock_ci_server.streaming_enabled = True
|
||||
|
||||
# Streaming was attempted first (returned 404), then fell back to batch
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
# Verify output still made it through
|
||||
delta_packets = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, PythonToolDelta)
|
||||
]
|
||||
assert len(delta_packets) >= 1
|
||||
first_delta = delta_packets[0].obj
|
||||
assert isinstance(first_delta, PythonToolDelta)
|
||||
assert "mock output" in first_delta.stdout
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
|
||||
|
||||
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
|
||||
environment-variable check passes and the DB flag is the deciding factor.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be unavailable when
|
||||
server_enabled is False in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=False)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
|
||||
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be available when
|
||||
server_enabled is True in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=True)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
ENTRYPOINT ["pytest", "-s", "-rs"]
|
||||
ENTRYPOINT ["pytest", "-s"]
|
||||
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
|
||||
|
||||
@@ -4,12 +4,10 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -34,6 +32,7 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -66,6 +65,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -75,20 +75,9 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -124,12 +113,7 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -142,25 +126,11 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and default_model.model_name in model_names
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return DefaultModel(**response.json())
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
@@ -9,10 +8,8 @@ from requests.models import CaseInsensitiveDict
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -72,42 +69,9 @@ class QueryHistoryManager:
|
||||
if end_time:
|
||||
query_params["end"] = end_time.isoformat()
|
||||
|
||||
start_response = requests.post(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
start_response.raise_for_status()
|
||||
request_id = start_response.json()["request_id"]
|
||||
|
||||
deadline = time.time() + MAX_DELAY
|
||||
while time.time() < deadline:
|
||||
status_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/export-status",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
status = status_response.json()["status"]
|
||||
if status == TaskStatus.SUCCESS:
|
||||
break
|
||||
if status == TaskStatus.FAILURE:
|
||||
raise RuntimeError("Query history export task failed")
|
||||
time.sleep(2)
|
||||
else:
|
||||
raise TimeoutError(
|
||||
f"Query history export not completed within {MAX_DELAY} seconds"
|
||||
)
|
||||
|
||||
download_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/download",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
download_response.raise_for_status()
|
||||
|
||||
if not download_response.content:
|
||||
raise RuntimeError(
|
||||
"Query history CSV download returned zero-length content"
|
||||
)
|
||||
|
||||
return download_response.headers, download_response.content.decode()
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
|
||||
@@ -116,6 +116,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
|
||||
@@ -6,26 +6,16 @@ import pytest
|
||||
from onyx.connectors.slack.models import ChannelType
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
|
||||
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
|
||||
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
|
||||
# from tests.load_env_vars import load_env_vars
|
||||
|
||||
# load_env_vars()
|
||||
|
||||
|
||||
def _provision_slack_channels(
|
||||
bot_token: str,
|
||||
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
|
||||
auth_info = slack_client.auth_test()
|
||||
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
if SLACK_ADMIN_EMAIL not in user_map:
|
||||
raise KeyError(
|
||||
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
|
||||
f"Available emails: {sorted(user_map.keys())}"
|
||||
)
|
||||
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
|
||||
admin_user_id = user_map["admin@example.com"]
|
||||
|
||||
(
|
||||
public_channel,
|
||||
@@ -37,16 +27,5 @@ def _provision_slack_channels(
|
||||
|
||||
yield public_channel, private_channel
|
||||
|
||||
# This part will always run after the test, even if it fails
|
||||
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_perm_sync_test_setup() -> (
|
||||
Generator[tuple[ChannelType, ChannelType], None, None]
|
||||
):
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user