mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 14:45:46 +00:00
Compare commits
109 Commits
experiment
...
tokenizer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d9ccd8bc9 | ||
|
|
24ac8b37d3 | ||
|
|
be8b108ae4 | ||
|
|
f380a75df3 | ||
|
|
21ec93663b | ||
|
|
d789c74024 | ||
|
|
fe014776f7 | ||
|
|
700ca0e0fc | ||
|
|
a84f8238ec | ||
|
|
4fc802e19d | ||
|
|
6cfd49439a | ||
|
|
71a1faa47e | ||
|
|
1a65217baf | ||
|
|
30fa43b5fc | ||
|
|
28332fa24b | ||
|
|
1f5050f9f6 | ||
|
|
3c1d29d3cf | ||
|
|
709e3f4ca7 | ||
|
|
dfa27c08ef | ||
|
|
13d60dcb0e | ||
|
|
30704f427f | ||
|
|
4f3c54f282 | ||
|
|
580d41dc23 | ||
|
|
897e181d67 | ||
|
|
fd322a8a10 | ||
|
|
11c54bafb5 | ||
|
|
c93617df5d | ||
|
|
0cdd438f46 | ||
|
|
31aef36f78 | ||
|
|
0c35dfc0e4 | ||
|
|
a9769757fe | ||
|
|
15d8946f40 | ||
|
|
ba79539d6d | ||
|
|
59d3725fc6 | ||
|
|
9c05bd215d | ||
|
|
4d2aa09654 | ||
|
|
16c07c8756 | ||
|
|
3fb4f5d6e6 | ||
|
|
14fab7fcdf | ||
|
|
22a335fffa | ||
|
|
b0f7466eba | ||
|
|
b1d42726b1 | ||
|
|
7d922bffc1 | ||
|
|
de7fc36fc5 | ||
|
|
7f9e37450d | ||
|
|
c7ef85b733 | ||
|
|
bd9319e592 | ||
|
|
db5955d6f2 | ||
|
|
5e447440ea | ||
|
|
78c6ca39b8 | ||
|
|
71a7cf09b3 | ||
|
|
91d30a0156 | ||
|
|
7b30752767 | ||
|
|
4450ecf07c | ||
|
|
0e6b766996 | ||
|
|
12c8cd338b | ||
|
|
ad5688bf65 | ||
|
|
d2deefd1f1 | ||
|
|
18b90d405d | ||
|
|
8394e8837b | ||
|
|
f06df891c4 | ||
|
|
d6d5e72c18 | ||
|
|
449f5d62f9 | ||
|
|
4d256c5666 | ||
|
|
2e53496f46 | ||
|
|
63a206706a | ||
|
|
28427b3e5f | ||
|
|
3cafcd8a5e | ||
|
|
f2c50b7bb5 | ||
|
|
6b28c6bbfc | ||
|
|
226e801665 | ||
|
|
be13aa1310 | ||
|
|
45d38c4906 | ||
|
|
8aab518532 | ||
|
|
da6ce10e86 | ||
|
|
aaf8253520 | ||
|
|
7c7f81b164 | ||
|
|
2d4a3c72e9 | ||
|
|
7c51712018 | ||
|
|
aa5614695d | ||
|
|
8d7255d3c4 | ||
|
|
d403498f48 | ||
|
|
9ef3095c17 | ||
|
|
a39e93a0cb | ||
|
|
46d73cdfee | ||
|
|
1e04ce78e0 | ||
|
|
f9b81c1725 | ||
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 | ||
|
|
5c5a6f943b | ||
|
|
d04128b8b1 | ||
|
|
bbebdf8f78 | ||
|
|
161279a2d5 | ||
|
|
e5ebb45a20 | ||
|
|
320ba9cb1b | ||
|
|
f2e8cb3114 | ||
|
|
43054a28ec | ||
|
|
dc74aa7b1f | ||
|
|
bd773191c2 | ||
|
|
66dbff41e6 | ||
|
|
1dcffe38bc | ||
|
|
c35e883564 | ||
|
|
fefcd58481 | ||
|
|
bdc89d9e3f |
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: "Build Backend Image"
|
||||
description: "Builds and pushes the backend Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
docker-no-cache:
|
||||
description: "Set to 'true' to disable docker build cache"
|
||||
required: false
|
||||
default: "false"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
|
||||
no-cache: ${{ inputs.docker-no-cache == 'true' }}
|
||||
76
.github/actions/build-integration-image/action.yml
vendored
Normal file
76
.github/actions/build-integration-image/action.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
name: "Build Integration Image"
|
||||
description: "Builds and pushes the integration test image with docker bake"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
name: "Build Model Server Image"
|
||||
description: "Builds and pushes the model server Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
|
||||
130
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
130
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
name: "Run Nightly Provider Chat Test"
|
||||
description: "Starts required compose services and runs nightly provider integration test"
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
|
||||
required: true
|
||||
models:
|
||||
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: false
|
||||
default: ""
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
api-base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
api-version:
|
||||
description: "Optional NIGHTLY_LLM_API_VERSION"
|
||||
required: false
|
||||
default: ""
|
||||
deployment-name:
|
||||
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
default: ""
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in image tags"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
shell: bash
|
||||
env:
|
||||
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
run: |
|
||||
cat <<EOF2 > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
|
||||
- name: Start Docker containers
|
||||
shell: bash
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server
|
||||
|
||||
- name: Run nightly provider integration test
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
env:
|
||||
MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
|
||||
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
|
||||
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
|
||||
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,8 +426,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -499,8 +500,9 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -646,8 +648,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -728,8 +730,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -862,8 +864,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -934,8 +937,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1072,8 +1076,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1145,8 +1149,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1287,8 +1291,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1366,8 +1371,9 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
49
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
49
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
|
||||
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -114,8 +114,10 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -603,7 +603,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,7 +21,13 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -52,21 +58,14 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -89,6 +89,10 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
329
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
329
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
name: Reusable Nightly LLM Provider Chat Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_models:
|
||||
description: "Comma-separated models for azure"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
ollama_models:
|
||||
description: "Comma-separated models for ollama_chat"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
openrouter_models:
|
||||
description: "Comma-separated models for openrouter"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_api_base:
|
||||
description: "API base for azure provider"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
strict:
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose down -v
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add python tool on default
|
||||
|
||||
Revision ID: 57122d037335
|
||||
Revises: c0c937d5c9e5
|
||||
Create Date: 2026-02-27 10:10:40.124925
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "57122d037335"
|
||||
down_revision = "c0c937d5c9e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up the PythonTool id
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
tool_id = result[0]
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE persona_id = 0 AND tool_id = :tool_id
|
||||
"""
|
||||
),
|
||||
{"tool_id": result[0]},
|
||||
)
|
||||
@@ -0,0 +1,112 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_persona_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_file", "needs_persona_sync")
|
||||
@@ -0,0 +1,70 @@
|
||||
"""llm provider deprecate fields
|
||||
|
||||
Revision ID: c0c937d5c9e5
|
||||
Revises: 8ffcc2bcfc11
|
||||
Create Date: 2026-02-25 17:35:46.125102
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0c937d5c9e5"
|
||||
down_revision = "8ffcc2bcfc11"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
|
||||
op.drop_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
type_="unique",
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore unique constraint on is_default_provider
|
||||
op.create_unique_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
["is_default_provider"],
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
@@ -128,12 +129,19 @@ class ScimDAL(DAL):
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
department=f.department,
|
||||
manager=f.manager,
|
||||
given_name=f.given_name,
|
||||
family_name=f.family_name,
|
||||
scim_emails_json=f.scim_emails_json,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
@@ -311,8 +319,14 @@ class ScimDAL(DAL):
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
@@ -320,11 +334,18 @@ class ScimDAL(DAL):
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
@@ -598,8 +597,12 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
# Keep percent-encoding intact so the path matches the encoding
|
||||
# used by the Office365 library's SPResPath.create_relative(),
|
||||
# which compares against urlparse(context.base_url).path.
|
||||
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
|
||||
# the site prefix in the constructed URL.
|
||||
server_relative_url = urlparse(site_url).path
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -167,6 +168,7 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,7 +15,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -24,16 +26,17 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
@@ -41,6 +44,8 @@ from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -48,20 +53,49 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
|
||||
media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
|
||||
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -86,15 +120,39 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -102,15 +160,45 @@ def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return JSONResponse(
|
||||
return ScimJSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
|
||||
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
|
||||
|
||||
Returns a set of lowercased attribute names to omit from responses.
|
||||
"""
|
||||
if not raw:
|
||||
return set()
|
||||
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
|
||||
|
||||
|
||||
def _apply_exclusions(
|
||||
resource: ScimUserResource | ScimGroupResource,
|
||||
excluded: set[str],
|
||||
) -> dict:
|
||||
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
|
||||
|
||||
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
|
||||
to reduce response payload size. We strip those fields after serialization
|
||||
so the rest of the pipeline doesn't need to know about them.
|
||||
"""
|
||||
data = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
for attr in excluded:
|
||||
# Match case-insensitively against the camelCase field names
|
||||
keys_to_remove = [k for k in data if k.lower() == attr]
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
return data
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -124,7 +212,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
@@ -144,10 +232,95 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
# If the client explicitly provides ``formatted``, prefer it — the client
|
||||
# knows what display string it wants. Otherwise build from components.
|
||||
if name.formatted:
|
||||
return name.formatted
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or name.formatted
|
||||
return parts or None
|
||||
|
||||
|
||||
def _scim_resource_response(
|
||||
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
|
||||
status_code: int = 200,
|
||||
) -> ScimJSONResponse:
|
||||
"""Serialize a SCIM resource as ``application/scim+json``."""
|
||||
content = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
return ScimJSONResponse(
|
||||
status_code=status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_list_response(
|
||||
resources: list[ScimUserResource | ScimGroupResource],
|
||||
total: int,
|
||||
start_index: int,
|
||||
count: int,
|
||||
excluded: set[str] | None = None,
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""Build a SCIM list response, optionally applying attribute exclusions.
|
||||
|
||||
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
|
||||
the ``excludedAttributes`` query parameter.
|
||||
"""
|
||||
if excluded:
|
||||
envelope = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = envelope.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
return _scim_resource_response(
|
||||
ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
|
||||
resource: ScimUserResource,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract department and manager from enterprise extension."""
|
||||
ext = resource.enterprise_extension
|
||||
if not ext:
|
||||
return None, None
|
||||
department = ext.department
|
||||
manager = ext.manager.value if ext.manager else None
|
||||
return department, manager
|
||||
|
||||
|
||||
def _mapping_to_fields(
|
||||
mapping: ScimUserMapping | None,
|
||||
) -> ScimMappingFields | None:
|
||||
"""Extract round-trip fields from a SCIM user mapping."""
|
||||
if not mapping:
|
||||
return None
|
||||
return ScimMappingFields(
|
||||
department=mapping.department,
|
||||
manager=mapping.manager,
|
||||
given_name=mapping.given_name,
|
||||
family_name=mapping.family_name,
|
||||
scim_emails_json=mapping.scim_emails_json,
|
||||
)
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
"""Build mapping fields from an incoming SCIM user resource."""
|
||||
department, manager = _extract_enterprise_fields(resource)
|
||||
return ScimMappingFields(
|
||||
department=department,
|
||||
manager=manager,
|
||||
given_name=resource.name.givenName if resource.name else None,
|
||||
family_name=resource.name.familyName if resource.name else None,
|
||||
scim_emails_json=serialize_emails(resource.emails),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,15 +331,17 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -185,42 +360,55 @@ def list_users(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return provider.build_user_resource(
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
@@ -228,19 +416,13 @@ def create_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
@@ -267,16 +449,31 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
# Create SCIM mapping when externalId is provided — this is how the IdP
|
||||
# correlates this user on subsequent requests. Per RFC 7643, externalId
|
||||
# is optional and assigned by the provisioning client.
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
fields = _fields_from_resource(user_resource)
|
||||
if external_id:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -286,13 +483,13 @@ def replace_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -313,15 +510,24 @@ def replace_user(
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -332,7 +538,7 @@ def patch_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -342,23 +548,25 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(
|
||||
patched, ent_data = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
@@ -393,17 +601,37 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails)
|
||||
if patched.emails is not None
|
||||
else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -412,25 +640,29 @@ def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
) -> Response | ScimJSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
A second DELETE returns 404 per RFC 7644 §3.6.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
dal.deactivate_user(user)
|
||||
|
||||
# If no SCIM mapping exists, the user was already deleted from
|
||||
# SCIM's perspective — return 404 per RFC 7644 §3.6.
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
if not mapping:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
dal.deactivate_user(user)
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -442,7 +674,7 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
@@ -497,15 +729,17 @@ def _validate_and_parse_members(
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -522,37 +756,47 @@ def list_groups(
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return provider.build_group_resource(
|
||||
resource = provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
@@ -560,7 +804,7 @@ def create_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -596,7 +840,10 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(db_group, members, external_id),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -606,13 +853,13 @@ def replace_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -627,7 +874,9 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, group_resource.externalId)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -637,7 +886,7 @@ def patch_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
@@ -646,7 +895,7 @@ def patch_group(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -685,7 +934,9 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, patched.externalId)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
@@ -693,13 +944,13 @@ def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
) -> Response | ScimJSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -82,23 +96,14 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
|
||||
return token
|
||||
|
||||
@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -31,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -70,6 +75,36 @@ class ScimUserGroupRef(BaseModel):
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
|
||||
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
|
||||
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
|
||||
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
|
||||
|
||||
department: str | None = None
|
||||
manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScimMappingFields:
|
||||
"""Stored SCIM mapping fields that need to round-trip through the IdP.
|
||||
|
||||
Entra ID sends structured name components, email metadata, and enterprise
|
||||
extension attributes that must be returned verbatim in subsequent GET
|
||||
responses. These fields are persisted on ScimUserMapping and threaded
|
||||
through the DAL, provider, and endpoint layers.
|
||||
"""
|
||||
|
||||
department: str | None = None
|
||||
manager: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -78,6 +113,8 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
@@ -88,6 +125,10 @@ class ScimUserResource(BaseModel):
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
@@ -165,6 +206,19 @@ class ScimPatchOperation(BaseModel):
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
@field_validator("op", mode="before")
|
||||
@classmethod
|
||||
def normalize_operation(cls, v: object) -> object:
|
||||
"""Normalize op to lowercase for case-insensitive matching.
|
||||
|
||||
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
|
||||
instead of ``"replace"``. This is safe for all providers since the
|
||||
enum values are lowercase. If a future provider requires other
|
||||
pre-processing quirks, move patch deserialization into the provider
|
||||
subclass instead of adding more special cases here.
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
@@ -14,8 +14,13 @@ responsible for persisting changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
@@ -24,6 +29,55 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[.+\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -34,18 +88,25 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@dataclass
|
||||
class _UserPatchCtx:
|
||||
"""Bundles the mutable state for user PATCH operations."""
|
||||
|
||||
data: dict[str, Any]
|
||||
name_data: dict[str, Any]
|
||||
ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
) -> tuple[ScimUserResource, dict[str, str | None]]:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
@@ -53,79 +114,185 @@ def apply_user_patch(
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
Returns:
|
||||
A tuple of (modified user resource, enterprise extension data dict).
|
||||
The enterprise dict has keys ``"department"`` and ``"manager"``
|
||||
with values set only when a PATCH operation touched them.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
name_data = data.get("name") or {}
|
||||
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
|
||||
_apply_user_replace(op, ctx, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_user_remove(op, ctx, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
# No path — value is a resource dict of top-level attributes to set.
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
_set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a remove operation to user data — clears the target field."""
|
||||
path = (op.path or "").lower()
|
||||
if not path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _USER_REMOVE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = None
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
"""Set a single field on user data by SCIM path.
|
||||
|
||||
Args:
|
||||
strict: When ``False`` (path-less replace), unknown attributes are
|
||||
silently skipped. When ``True`` (explicit path), they raise.
|
||||
"""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
|
||||
# Simple field writes handled by the dispatch table
|
||||
entry = _USER_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = value
|
||||
return
|
||||
|
||||
# displayName sets both the top-level field and the name.formatted sub-field
|
||||
if path == "displayname":
|
||||
ctx.data["displayName"] = value
|
||||
ctx.name_data["formatted"] = value
|
||||
elif path == "name":
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
ctx.name_data[k] = v
|
||||
elif path == "emails":
|
||||
if isinstance(value, list):
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif path.startswith(_ENTERPRISE_URN_LOWER):
|
||||
_set_enterprise_field(path, value, ctx.ent_data)
|
||||
elif not strict:
|
||||
return
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via an email filter path."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ent_data: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Handle enterprise extension URN paths or value dicts."""
|
||||
# Full URN as key with dict value (path-less PATCH)
|
||||
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
|
||||
if path == _ENTERPRISE_URN_LOWER:
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
if "department" in d:
|
||||
ent_data["department"] = d["department"]
|
||||
if "manager" in d:
|
||||
mgr = d["manager"]
|
||||
if isinstance(mgr, dict):
|
||||
ent_data["manager"] = mgr.get("value")
|
||||
return
|
||||
|
||||
# Dotted URN path, e.g. "urn:...:user:department"
|
||||
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
|
||||
if suffix == "department":
|
||||
ent_data["department"] = str(value) if value is not None else None
|
||||
elif suffix == "manager":
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
ent_data["manager"] = d.get("value")
|
||||
elif isinstance(value, str):
|
||||
ent_data["manager"] = value
|
||||
else:
|
||||
# Unknown enterprise attributes are silently ignored rather than
|
||||
# rejected — IdPs may send attributes we don't model yet.
|
||||
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
@@ -235,12 +402,14 @@ def _set_group_field(
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
entry = _GROUP_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, _ = entry
|
||||
data[key] = value
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
|
||||
@@ -2,13 +2,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
@@ -17,6 +26,17 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
@@ -41,12 +61,22 @@ class ScimProvider(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
"""Schema URIs to include in User resource responses.
|
||||
|
||||
Override in subclasses to advertise additional schemas (e.g. the
|
||||
enterprise extension for Entra ID).
|
||||
"""
|
||||
return [SCIM_USER_SCHEMA]
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
@@ -58,27 +88,48 @@ class ScimProvider(ABC):
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
fields: Stored mapping fields that the IdP expects round-tripped.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
# Build enterprise extension when at least one value is present.
|
||||
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
|
||||
enterprise_ext: ScimEnterpriseExtension | None = None
|
||||
schemas = list(self.user_schemas)
|
||||
if f.department is not None or f.manager is not None:
|
||||
manager_ref = (
|
||||
ScimManagerRef(value=f.manager) if f.manager is not None else None
|
||||
)
|
||||
enterprise_ext = ScimEnterpriseExtension(
|
||||
department=f.department,
|
||||
manager=manager_ref,
|
||||
)
|
||||
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
|
||||
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
|
||||
name = self.build_scim_name(user, f)
|
||||
emails = _deserialize_emails(f.scim_emails_json, username)
|
||||
|
||||
resource = ScimUserResource(
|
||||
schemas=schemas,
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
name=name,
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
emails=emails,
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
resource.enterprise_extension = enterprise_ext
|
||||
return resource
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
@@ -98,19 +149,57 @@ class ScimProvider(ABC):
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
def build_scim_name(
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
return ScimName(givenName="", familyName="", formatted="")
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
|
||||
"""Deserialize stored email entries or build a default work email."""
|
||||
if stored_json:
|
||||
try:
|
||||
entries = json.loads(stored_json)
|
||||
if isinstance(entries, list) and entries:
|
||||
return [ScimEmail(**e) for e in entries]
|
||||
except (json.JSONDecodeError, TypeError, ValidationError):
|
||||
logger.warning(
|
||||
"Corrupt scim_emails_json, falling back to default: %s", stored_json
|
||||
)
|
||||
return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
|
||||
"""Serialize SCIM email entries to JSON for storage."""
|
||||
if not emails:
|
||||
return None
|
||||
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
|
||||
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
|
||||
"""Entra ID (Azure AD) SCIM provider.
|
||||
|
||||
Entra behavioral notes:
|
||||
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
|
||||
— handled by ``ScimPatchOperation.normalize_op`` validator.
|
||||
- Sends the enterprise extension URN as a key in path-less PATCH value
|
||||
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
|
||||
store department/manager values.
|
||||
- Expects the enterprise extension schema in ``schemas`` arrays and
|
||||
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "entra"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return _ENTRA_IGNORED_PATCH_PATHS
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
@@ -22,4 +23,4 @@ class OktaProvider(ScimProvider):
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
return COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"schemaExtensions": [
|
||||
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
],
|
||||
)
|
||||
|
||||
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_ENTERPRISE_USER_SCHEMA,
|
||||
name="EnterpriseUser",
|
||||
description="Enterprise User extension (RFC 7643 §4.3)",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="department",
|
||||
type="string",
|
||||
description="Department.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="manager",
|
||||
type="complex",
|
||||
description="The user's manager.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Manager user ID.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
|
||||
@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
@@ -117,15 +117,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
if llm_upsert_requests:
|
||||
logger.notice("Seeding LLMs")
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
if not llm_upsert_requests:
|
||||
return
|
||||
|
||||
logger.notice("Seeding LLMs")
|
||||
for request in llm_upsert_requests:
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
)
|
||||
if not default_provider:
|
||||
return
|
||||
|
||||
visible_configs = [
|
||||
mc for mc in default_provider.model_configurations if mc.is_visible
|
||||
]
|
||||
default_config = (
|
||||
visible_configs[0]
|
||||
if visible_configs
|
||||
else default_provider.model_configurations[0]
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=default_provider.id,
|
||||
model_name=default_config.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
@@ -137,12 +160,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -154,6 +171,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
settings.ee_features_enabled = False
|
||||
elif metadata.used_seats > metadata.seats:
|
||||
# License is valid but seat limit exceeded
|
||||
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
|
||||
settings.seat_count = metadata.seats
|
||||
settings.used_seats = metadata.used_seats
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=request.name, db_session=db_session
|
||||
)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
@@ -725,11 +725,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -12,6 +12,7 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
@@ -75,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@@ -413,34 +414,31 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -448,15 +446,18 @@ def process_single_user_file(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -470,7 +471,6 @@ def process_single_user_file(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -492,9 +492,8 @@ def process_single_user_file(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -503,33 +502,42 @@ def process_single_user_file(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -580,36 +588,38 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -647,7 +657,6 @@ def process_single_user_file_delete(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -655,26 +664,33 @@ def process_single_user_file_delete(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -712,7 +728,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -743,43 +762,44 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
return None
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -800,20 +820,25 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
@@ -822,11 +847,21 @@ def process_single_user_file_project_sync(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
287
backend/onyx/background/periodic_poller.py
Normal file
287
backend/onyx/background/periodic_poller.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,3 +1,33 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -9,3 +39,142 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
# Commit to release the row lock promptly.
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
45
backend/onyx/cache/factory.py
vendored
Normal file
45
backend/onyx/cache/factory.py
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
# CacheBackendType.POSTGRES will be added in a follow-up PR.
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
89
backend/onyx/cache/interface.py
vendored
Normal file
89
backend/onyx/cache/interface.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
92
backend/onyx/cache/redis_backend.py
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
|
||||
|
||||
This is a thin pass-through — every method maps 1-to-1 to the underlying
|
||||
Redis command. ``TenantRedis`` key-prefixing is handled by the client
|
||||
itself (provided by ``get_redis_client``).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
val = self._r.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
return str(val).encode()
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self._r.set(key, value, ex=ex)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._r.delete(key)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return bool(self._r.exists(key))
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._r.expire(key, seconds)
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return cast(int, self._r.ttl(key))
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return RedisCacheLock(self._r.lock(name, timeout=timeout))
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self._r.rpush(key, value)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1])
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -163,13 +162,11 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -15,10 +16,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -203,17 +204,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,8 +222,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -242,29 +242,28 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if context_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -275,7 +274,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -289,7 +288,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -445,17 +444,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if context_files and context_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -466,14 +465,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 5. Add last user message (with context images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -532,11 +531,13 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file:"
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -547,11 +548,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -559,21 +560,25 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from project_files
|
||||
# Use pre-calculated token count from context_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=project_files.total_token_count,
|
||||
token_count=context_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -584,7 +589,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
context_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -627,9 +632,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -647,7 +652,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -788,7 +793,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
for uf in user_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -879,46 +864,54 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -54,6 +55,12 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
@@ -294,6 +301,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -23,7 +23,6 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
|
||||
@@ -872,6 +871,56 @@ class SharepointConnector(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
|
||||
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
|
||||
tenant domain is the first label of the hostname.
|
||||
"""
|
||||
for site_url in self.sites:
|
||||
try:
|
||||
hostname = urlsplit(site_url.strip()).hostname
|
||||
except ValueError:
|
||||
continue
|
||||
if not hostname:
|
||||
continue
|
||||
tenant = hostname.split(".")[0]
|
||||
if tenant:
|
||||
return tenant
|
||||
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
|
||||
return None
|
||||
|
||||
def _resolve_tenant_domain_from_root_site(self) -> str:
|
||||
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
|
||||
Sites.Read.All (a permission the connector already needs)."""
|
||||
root_site = self.graph_client.sites.root.get().execute_query()
|
||||
hostname = root_site.site_collection.hostname
|
||||
if not hostname:
|
||||
raise ConnectorValidationError(
|
||||
"Could not determine tenant domain from root site"
|
||||
)
|
||||
tenant_domain = hostname.split(".")[0]
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from root site hostname '%s'",
|
||||
tenant_domain,
|
||||
hostname,
|
||||
)
|
||||
return tenant_domain
|
||||
|
||||
def _resolve_tenant_domain(self) -> str:
|
||||
"""Determine the tenant domain, preferring site URLs over a Graph API
|
||||
call to avoid needing extra permissions."""
|
||||
from_sites = self._extract_tenant_domain_from_sites()
|
||||
if from_sites:
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from site URLs",
|
||||
from_sites,
|
||||
)
|
||||
return from_sites
|
||||
|
||||
logger.info("No site URLs available; resolving tenant domain from root site")
|
||||
return self._resolve_tenant_domain_from_root_site()
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
@@ -1589,6 +1638,11 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
if not sp_client_id:
|
||||
raise ConnectorValidationError("Client ID is required")
|
||||
if not sp_directory_id:
|
||||
raise ConnectorValidationError("Directory (tenant) ID is required")
|
||||
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
@@ -1641,21 +1695,7 @@ class SharepointConnector(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
raise ConnectorValidationError("No organization found")
|
||||
|
||||
tenant_info: Organization = org[
|
||||
0
|
||||
] # Access first item directly from collection
|
||||
if not tenant_info.verified_domains:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
|
||||
sp_tenant_domain = tenant_info.verified_domains[0].name
|
||||
if not sp_tenant_domain:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
# remove the .onmicrosoft.com part
|
||||
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -40,6 +40,7 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
@@ -118,6 +119,7 @@ def _build_index_filters(
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -265,6 +267,8 @@ def search_pipeline(
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
@@ -299,6 +303,7 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
|
||||
@@ -98,6 +98,7 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -112,6 +113,9 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import ssl_context
|
||||
from onyx.db.engine.sql_engine import ASYNC_DB_API
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
connect_args["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
cparams["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any
|
||||
@@ -48,11 +49,9 @@ def provide_iam_token(
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
@@ -186,6 +186,7 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
"""
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -213,11 +206,29 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
existing_llm_provider: LLMProviderModel | None = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_llm_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} not found"
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
if existing_llm_provider.name != llm_provider_upsert_request.name:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
|
||||
)
|
||||
else:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with name '{llm_provider_upsert_request.name}'"
|
||||
" already exists"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -238,11 +249,7 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -306,15 +313,6 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -488,6 +486,22 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -604,22 +618,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
provider.default_model_name,
|
||||
model_name,
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -805,12 +810,6 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -103,7 +103,6 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -2822,13 +2821,17 @@ class LLMProvider(Base):
|
||||
custom_config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# Deprecated: use LLMModelFlow with CHAT flow type instead
|
||||
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
@@ -2879,6 +2882,7 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
@@ -3260,19 +3264,6 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3299,11 +3290,8 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -4270,6 +4258,9 @@ class UserFile(Base):
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
needs_persona_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -18,11 +18,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -254,16 +251,15 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit default personas, but not make them
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -284,7 +280,6 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -298,10 +293,7 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
featured=create_persona_request.featured,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -335,6 +327,7 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -344,9 +337,7 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -360,6 +351,15 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -765,6 +765,9 @@ def mark_persona_as_deleted(
|
||||
) -> None:
|
||||
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
|
||||
persona.deleted = True
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -776,11 +779,13 @@ def mark_persona_as_not_deleted(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
|
||||
)
|
||||
if persona.deleted:
|
||||
persona.deleted = False
|
||||
db_session.commit()
|
||||
else:
|
||||
if not persona.deleted:
|
||||
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
|
||||
persona.deleted = False
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_delete_persona_by_name(
|
||||
@@ -846,14 +851,24 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
) -> None:
|
||||
"""Flag the given UserFile rows so the background sync task picks them up
|
||||
and updates their persona metadata in the vector DB."""
|
||||
if not user_file_ids:
|
||||
return
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
|
||||
{UserFile.needs_persona_sync: True},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -874,13 +889,11 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
featured: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -946,6 +959,8 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
@@ -989,12 +1004,6 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1008,10 +1017,8 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1034,8 +1041,13 @@ def upsert_persona(
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
old_file_ids = {uf.id for uf in existing_persona.user_files}
|
||||
new_file_ids = {uf.id for uf in (user_files or [])}
|
||||
affected_file_ids = old_file_ids | new_file_ids
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
|
||||
|
||||
if hierarchy_node_ids is not None:
|
||||
existing_persona.hierarchy_nodes.clear()
|
||||
@@ -1059,12 +1071,6 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1080,15 +1086,15 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
featured=(featured if featured is not None else False),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
attached_documents=attached_documents or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
if user_files:
|
||||
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
|
||||
persona = new_persona
|
||||
if commit:
|
||||
db_session.commit()
|
||||
@@ -1125,9 +1131,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_is_default(
|
||||
def update_persona_featured(
|
||||
persona_id: int,
|
||||
is_default: bool,
|
||||
featured: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1135,10 +1141,7 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
persona.featured = featured
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -51,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
@@ -105,8 +106,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -127,16 +128,27 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,8 +5,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -45,8 +43,6 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -73,17 +69,13 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,10 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
|
||||
@@ -56,10 +58,34 @@ def fetch_user_project_ids_for_user_files(
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch user project ids for specified user files"""
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
|
||||
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
|
||||
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
|
||||
)
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
user_file_id_to_project_ids: dict[str, list[int]] = {
|
||||
user_file_id: [] for user_file_id in user_file_ids
|
||||
}
|
||||
for user_file_id, project_id in rows:
|
||||
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
|
||||
|
||||
return user_file_id_to_project_ids
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch persona (assistant) ids for specified user files."""
|
||||
stmt = (
|
||||
select(UserFile)
|
||||
.where(UserFile.id.in_(user_file_ids))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
)
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [project.id for project in user_file.projects]
|
||||
str(user_file.id): [persona.id for persona in user_file.assistants]
|
||||
for user_file in results
|
||||
}
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_final_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -49,8 +50,11 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -118,8 +122,11 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -107,9 +113,8 @@ class OpenSearchClient:
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -125,6 +130,142 @@ class OpenSearchClient:
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -192,6 +333,38 @@ class OpenSearchClient:
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -610,43 +783,6 @@ class OpenSearchClient:
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -807,48 +943,6 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -6,7 +6,7 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -40,6 +40,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -93,6 +94,25 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -248,6 +268,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -258,10 +280,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -269,8 +287,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -279,9 +299,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -471,19 +490,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -492,10 +529,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings.
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -508,47 +550,33 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
|
||||
if not self._client.index_exists():
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -620,7 +648,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -660,7 +688,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
return self._client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -760,7 +788,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._os_client.update_document(
|
||||
self._client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -799,7 +827,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._os_client.search(
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -849,7 +877,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -881,7 +909,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -909,6 +937,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -525,7 +526,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -546,3 +547,41 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
@@ -144,6 +145,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -202,6 +204,7 @@ class DocumentQuery:
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -267,6 +270,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -334,6 +338,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -496,6 +501,7 @@ class DocumentQuery:
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -530,6 +536,8 @@ class DocumentQuery:
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -627,6 +635,9 @@ class DocumentQuery:
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
@@ -780,6 +791,9 @@ class DocumentQuery:
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -149,6 +150,18 @@ def build_vespa_filters(
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
) -> str:
|
||||
if persona_id is None:
|
||||
return ""
|
||||
try:
|
||||
pid = int(persona_id)
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -192,6 +205,9 @@ def build_vespa_filters(
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
@@ -119,6 +120,10 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -182,7 +187,7 @@ class UserFileIndexingAdapter:
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=[],
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
|
||||
@@ -67,6 +67,18 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -725,6 +737,44 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -736,6 +786,7 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -743,3 +794,4 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -32,11 +32,14 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -254,8 +257,53 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
|
||||
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
|
||||
|
||||
The Postgres cache backend eliminates the Redis dependency, but only works
|
||||
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
|
||||
"""
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
|
||||
raise RuntimeError(
|
||||
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
|
||||
"The Postgres cache backend is only supported in no-vector-DB "
|
||||
"deployments where Celery is replaced by the in-process task runner."
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
|
||||
since these modes require infrastructure that is removed in no-vector-DB deployments.
|
||||
"""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return
|
||||
|
||||
if MULTI_TENANT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
|
||||
"Multi-tenant deployments require the vector database for "
|
||||
"per-tenant document indexing and search. Run in single-tenant "
|
||||
"mode when disabling the vector database."
|
||||
)
|
||||
|
||||
from onyx.server.features.build.configs import ENABLE_CRAFT
|
||||
|
||||
if ENABLE_CRAFT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
|
||||
"Onyx Craft requires background workers for sandbox lifecycle "
|
||||
"management, which are removed in no-vector-DB deployments. "
|
||||
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -324,8 +372,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,10 +1,59 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
@@ -29,15 +78,21 @@ def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int |
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
|
||||
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
|
||||
if "[[" not in message:
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _CITATION_LINK_PATTERN.search(message, cursor):
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
@@ -57,18 +112,38 @@ def _normalize_citation_link_destinations(message: str) -> str:
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
normalized_message = _normalize_citation_link_destinations(message)
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -77,7 +152,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n"
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -96,7 +171,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -118,7 +193,30 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code}\n```\n"
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -3,10 +3,12 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -243,6 +245,44 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"check_seat_availability",
|
||||
None,
|
||||
)
|
||||
seat_result = check_seat_fn(db_session=db_session)
|
||||
if seat_result is not None and not seat_result.available:
|
||||
logger.info(
|
||||
f"Blocked inactive Slack user {message_info.email}: "
|
||||
f"{seat_result.error_message}"
|
||||
)
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
text=(
|
||||
"We weren't able to respond because your organization "
|
||||
"has reached its user seat limit. Your account is "
|
||||
"currently deactivated and cannot be reactivated "
|
||||
"until more seats are available. Please contact "
|
||||
"your Onyx administrator."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
activate_user(existing_user, db_session)
|
||||
invalidate_license_cache_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
"invalidate_license_cache",
|
||||
None,
|
||||
)
|
||||
invalidate_license_cache_fn()
|
||||
logger.info(f"Reactivated inactive Slack user {message_info.email}")
|
||||
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate AGENTS.md by scanning the files directory and populating the template.
|
||||
|
||||
This script runs at container startup, AFTER the init container has synced files
|
||||
from S3. It scans the /workspace/files directory to discover what knowledge sources
|
||||
are available and generates appropriate documentation.
|
||||
This script runs during session setup, AFTER files have been synced from S3
|
||||
and the files symlink has been created. It reads an existing AGENTS.md (which
|
||||
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
|
||||
placeholder by scanning the knowledge source directory, and writes it back.
|
||||
|
||||
Environment variables:
|
||||
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
|
||||
Arguments:
|
||||
agents_md_path: Path to the AGENTS.md file to update in place
|
||||
files_path: Path to the files directory to scan for knowledge sources
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -189,49 +193,39 @@ def build_knowledge_sources_section(files_path: Path) -> str:
|
||||
def main() -> None:
|
||||
"""Main entry point for container startup script.
|
||||
|
||||
Is called by the container startup script to scan /workspace/files and populate
|
||||
the knowledge sources section.
|
||||
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
|
||||
placeholder by scanning the files directory, and writes it back.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
"""
|
||||
# Read template from environment variable
|
||||
template = os.environ.get("AGENT_INSTRUCTIONS", "")
|
||||
if not template:
|
||||
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
|
||||
template = "# Agent Instructions\n\nNo instructions provided."
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Scan files directory - check /workspace/files first, then /workspace/demo_data
|
||||
files_path = Path("/workspace/files")
|
||||
demo_data_path = Path("/workspace/demo_data")
|
||||
agents_md_path = Path(sys.argv[1])
|
||||
files_path = Path(sys.argv[2])
|
||||
|
||||
# Use demo_data if files doesn't exist or is empty
|
||||
if not files_path.exists() or not any(files_path.iterdir()):
|
||||
if demo_data_path.exists():
|
||||
files_path = demo_data_path
|
||||
if not agents_md_path.exists():
|
||||
print(f"Error: {agents_md_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(files_path)
|
||||
template = agents_md_path.read_text()
|
||||
|
||||
# Replace placeholders
|
||||
content = template
|
||||
content = content.replace(
|
||||
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
|
||||
resolved_files_path = files_path.resolve()
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
|
||||
|
||||
# Replace placeholder and write back
|
||||
content = template.replace(
|
||||
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
|
||||
)
|
||||
|
||||
# Write AGENTS.md
|
||||
output_path = Path("/workspace/AGENTS.md")
|
||||
output_path.write_text(content)
|
||||
|
||||
# Log result
|
||||
source_count = 0
|
||||
if files_path.exists():
|
||||
source_count = len(
|
||||
[
|
||||
d
|
||||
for d in files_path.iterdir()
|
||||
if d.is_dir() and not d.name.startswith(".")
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
|
||||
)
|
||||
agents_md_path.write_text(content)
|
||||
print(f"Populated knowledge sources in {agents_md_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1352,6 +1352,9 @@ fi
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
@@ -1780,6 +1783,9 @@ ln -sf {symlink_target} {session_path}/files
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
persona_id: int,
|
||||
is_default_request: IsDefaultRequest,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_is_default(
|
||||
update_persona_featured(
|
||||
persona_id=persona_id,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
featured=is_featured_request.featured,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona default status")
|
||||
logger.exception("Failed to update persona featured status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@@ -405,6 +405,7 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -415,14 +416,22 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -108,11 +107,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -128,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
featured: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -155,9 +150,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -175,7 +167,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -214,8 +206,6 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -230,7 +220,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -252,11 +242,9 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
featured: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -265,7 +253,6 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -289,11 +276,9 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -324,7 +309,6 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -332,12 +316,10 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on perona's internal settings
|
||||
# Model with full context on persona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -360,7 +342,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
featured=persona.featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -391,10 +373,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -12,13 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -34,7 +29,6 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -55,7 +49,27 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -111,6 +125,7 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -137,12 +152,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -192,6 +207,7 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -208,7 +224,6 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -224,7 +239,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -237,6 +252,7 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -268,7 +284,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -335,7 +351,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -354,13 +370,11 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
)
|
||||
|
||||
|
||||
@@ -426,6 +440,7 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -458,15 +473,25 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -7,13 +7,14 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -116,7 +117,9 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -128,11 +131,11 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
llm = get_default_llm()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -134,11 +133,7 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -147,7 +142,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
is_default_persona=False,
|
||||
featured=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -97,7 +97,6 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -233,12 +237,9 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -268,7 +269,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -303,7 +304,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -328,7 +329,15 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -344,18 +353,44 @@ def put_llm_provider(
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -393,22 +428,6 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -438,8 +457,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -469,28 +488,29 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
@admin_router.post("/default")
|
||||
def set_provider_as_default(
|
||||
provider_id: int,
|
||||
default_model_request: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
@admin_router.post("/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
default_model: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
provider_id=default_model.provider_id,
|
||||
vision_model=default_model.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -516,7 +536,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -545,7 +565,13 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
return vision_provider_response
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -577,9 +603,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -592,7 +618,15 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -604,7 +638,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are always included.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -618,7 +652,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -635,11 +669,11 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -668,7 +702,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Use simplified access check - public providers always included
|
||||
# Check access with persona context — respects persona restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -682,7 +716,51 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return llm_provider_list
|
||||
# Get the default model and vision model for the persona
|
||||
# TODO: Port persona's over to use ID
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text = DefaultModel.from_model_config(default_text_model)
|
||||
default_vision = DefaultModel.from_model_config(default_vision_model)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider and can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
# There is no logic that requires sending the providers default vision model name
|
||||
# We only send for the one that is actually the default
|
||||
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
|
||||
"""Find the default conversation model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns empty string.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
|
||||
return model_config.name
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
|
||||
"""Find the default vision model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns None.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
|
||||
return model_config.name
|
||||
return None
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
id: int | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,22 +72,12 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -128,18 +91,17 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -155,8 +117,6 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -178,14 +138,6 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -198,10 +150,6 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -421,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -198,7 +198,6 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
27
backend/onyx/server/metrics/per_tenant.py
Normal file
27
backend/onyx/server/metrics/per_tenant.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Per-tenant request counter metric.
|
||||
|
||||
Increments a counter on every request, labelled by tenant, so Grafana can
|
||||
answer "which tenant is generating the most traffic?"
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_requests_by_tenant = Counter(
|
||||
"onyx_api_requests_by_tenant_total",
|
||||
"Total API requests by tenant",
|
||||
["tenant_id", "method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def per_tenant_request_callback(info: Info) -> None:
|
||||
"""Increment per-tenant request counter for every request."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
_requests_by_tenant.labels(
|
||||
tenant_id=tenant_id,
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
@@ -32,6 +32,7 @@ from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -72,7 +73,7 @@ _checkout_timeout_total = Counter(
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine"],
|
||||
["handler", "engine", "tenant_id"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
@@ -163,10 +164,14 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_tenant_id"] = tenant_id
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
@@ -174,9 +179,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
@@ -199,9 +207,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
|
||||
@@ -11,9 +11,11 @@ SQLAlchemy connection pool metrics are registered separately via
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -59,6 +61,15 @@ def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
# Explicitly create the default metrics (http_requests_total,
|
||||
# http_request_duration_seconds, etc.) and add them first. The library
|
||||
# skips creating defaults when ANY custom instrumentations are registered
|
||||
# via .add(), so we must include them ourselves.
|
||||
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
|
||||
if default_callback:
|
||||
instrumentator.add(default_callback)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
instrumentator.add(per_tenant_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
|
||||
@@ -152,10 +152,20 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -163,11 +173,16 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -181,7 +196,8 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
],
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GRACE_PERIOD = "grace_period"
|
||||
GATED_ACCESS = "gated_access"
|
||||
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -82,6 +83,10 @@ class Settings(BaseModel):
|
||||
# Default Assistant settings
|
||||
disable_default_assistant: bool | None = False
|
||||
|
||||
# Seat usage - populated by license enforcement when seat limit is exceeded
|
||||
seat_count: int | None = None
|
||||
used_seats: int | None = None
|
||||
|
||||
# OpenSearch migration
|
||||
opensearch_indexing_enabled: bool = False
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -24,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -32,6 +34,9 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -250,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
provider_name = "DevEnvPresetOpenAI"
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
name="DevEnvPresetOpenAI",
|
||||
id=existing.id if existing else None,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -269,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
@@ -311,7 +322,14 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ def generate_intermediate_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ def run_research_agent_call(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=msg_history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ logger = setup_logger()
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,6 +181,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -427,6 +429,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -247,6 +247,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,6 +263,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -456,6 +459,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
|
||||
@@ -8,37 +8,3 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# fastapi-users
|
||||
@@ -528,7 +528,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.3
|
||||
lxml-html-clean==0.4.4
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.6.2
|
||||
pypdf==6.7.5
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
@@ -1216,7 +1217,7 @@ websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
# google-genai
|
||||
werkzeug==3.1.5
|
||||
werkzeug==3.1.6
|
||||
# via sendgrid
|
||||
wrapt==1.17.3
|
||||
# via
|
||||
|
||||
@@ -125,7 +125,7 @@ executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.1
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -90,7 +90,7 @@ docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user