mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 06:05:46 +00:00
Compare commits
43 Commits
litellm_pr
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf3c98142d | ||
|
|
f023599618 | ||
|
|
c55cb899f7 | ||
|
|
9b8a6e60b7 | ||
|
|
dd9d201b51 | ||
|
|
c545819aa6 | ||
|
|
960ee228bf | ||
|
|
dea5be2185 | ||
|
|
d083973d4f | ||
|
|
df956888bf | ||
|
|
7c6062e7d5 | ||
|
|
89d2759021 | ||
|
|
d9feaf43a7 | ||
|
|
5bfffefa2f | ||
|
|
4d0b7e14d4 | ||
|
|
36c55d9e59 | ||
|
|
9f652108f9 | ||
|
|
d4e4c6b40e | ||
|
|
9c8deb5d0c | ||
|
|
58f57c43aa | ||
|
|
62106df753 | ||
|
|
45b3a5e945 | ||
|
|
e19a6b6789 | ||
|
|
2de7df4839 | ||
|
|
bd054bbad9 | ||
|
|
313e709d41 | ||
|
|
aeb1d6edac | ||
|
|
49a35f8aaa | ||
|
|
049e8ef0e2 | ||
|
|
3b61b495a3 | ||
|
|
5c5c9f0e1d | ||
|
|
f20d5c33b7 | ||
|
|
e898407f7b | ||
|
|
f802ff09a7 | ||
|
|
69ad712e09 | ||
|
|
98b69c0f2c | ||
|
|
1e5c87896f | ||
|
|
b6cc97a8c3 | ||
|
|
032fbf1058 | ||
|
|
fc32a9f92a | ||
|
|
9be13bbf63 | ||
|
|
9e7176eb82 | ||
|
|
c7faf8ce52 |
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -640,6 +640,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
@@ -721,6 +722,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
|
||||
38
.github/workflows/pr-integration-tests.yml
vendored
38
.github/workflows/pr-integration-tests.yml
vendored
@@ -46,6 +46,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
editions: ${{ steps.set-editions.outputs.editions }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -72,6 +73,16 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Determine editions to test
|
||||
id: set-editions
|
||||
run: |
|
||||
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
@@ -267,7 +278,7 @@ jobs:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
@@ -275,6 +286,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -298,12 +310,11 @@ jobs:
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
EDITION: ${{ matrix.edition }}
|
||||
run: |
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -312,11 +323,20 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
# EE-only config
|
||||
if [ "$EDITION" = "ee" ]; then
|
||||
cat <<EOF >> deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -379,14 +399,14 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -444,7 +464,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
443
.github/workflows/pr-mit-integration-tests.yml
vendored
443
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -1,443 +0,0 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
wait_for_service() {
|
||||
local url=$1
|
||||
local label=$2
|
||||
local timeout=${3:-300} # default 5 minutes
|
||||
local start_time
|
||||
start_time=$(date +%s)
|
||||
|
||||
while true; do
|
||||
local current_time
|
||||
current_time=$(date +%s)
|
||||
local elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local response
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "${label} is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
|
||||
else
|
||||
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -275,7 +275,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -419,7 +419,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
"connector_doc_fetching"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
|
||||
- backend/ee/LICENSE
|
||||
- web/src/app/ee/LICENSE
|
||||
- web/src/ee/LICENSE
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Migrate to contextual rag model
|
||||
|
||||
Revision ID: 19c0ccb01687
|
||||
Revises: 9c54986124c6
|
||||
Create Date: 2026-02-12 11:21:41.798037
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "19c0ccb01687"
|
||||
down_revision = "9c54986124c6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
|
||||
# when the table was created with only CHAT/VISION values.
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=20),
|
||||
existing_type=sa.String(length=10),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# For every search_settings row that has contextual rag configured,
|
||||
# create an llm_model_flow entry. is_default is TRUE if the row
|
||||
# belongs to the PRESENT search settings, FALSE otherwise.
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
|
||||
SELECT DISTINCT
|
||||
'CONTEXTUAL_RAG',
|
||||
mc.id,
|
||||
(ss.status = 'PRESENT')
|
||||
FROM search_settings ss
|
||||
JOIN llm_provider lp
|
||||
ON lp.name = ss.contextual_rag_llm_provider
|
||||
JOIN model_configuration mc
|
||||
ON mc.llm_provider_id = lp.id
|
||||
AND mc.name = ss.contextual_rag_llm_name
|
||||
WHERE ss.enable_contextual_rag = TRUE
|
||||
AND ss.contextual_rag_llm_name IS NOT NULL
|
||||
AND ss.contextual_rag_llm_provider IS NOT NULL
|
||||
ON CONFLICT (llm_model_flow_type, model_configuration_id)
|
||||
DO UPDATE SET is_default = EXCLUDED.is_default
|
||||
WHERE EXCLUDED.is_default = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM llm_model_flow
|
||||
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=10),
|
||||
existing_type=sa.String(length=20),
|
||||
existing_nullable=False,
|
||||
)
|
||||
124
backend/alembic/versions/9c54986124c6_add_scim_tables.py
Normal file
124
backend/alembic/versions/9c54986124c6_add_scim_tables.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""add_scim_tables
|
||||
|
||||
Revision ID: 9c54986124c6
|
||||
Revises: b51c6844d1df
|
||||
Create Date: 2026-02-12 20:29:47.448614
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9c54986124c6"
|
||||
down_revision = "b51c6844d1df"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the SCIM provisioning tables.

    - scim_token: bearer tokens that identity providers present to the
      SCIM endpoints for authentication.
    - scim_group_mapping: maps the IdP's external group id to an Onyx
      user group (one mapping per group).
    - scim_user_mapping: maps the IdP's external user id to an Onyx user
      (one mapping per user).
    """
    op.create_table(
        "scim_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        # 64 chars — presumably a hex digest of the raw token (raw token is
        # never stored); TODO confirm the hashing scheme against the service.
        sa.Column("hashed_token", sa.String(length=64), nullable=False),
        # Obfuscated/truncated form of the token shown in the admin UI.
        sa.Column("token_display", sa.String(), nullable=False),
        sa.Column(
            "created_by_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "is_active",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
        # Tokens are removed along with the user that created them.
        sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("hashed_token"),
    )
    op.create_table(
        "scim_group_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"], ["user_group.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        # At most one SCIM mapping per user group.
        sa.UniqueConstraint("user_group_id"),
    )
    # External ids are unique per tenant — enforced via a unique index.
    op.create_index(
        op.f("ix_scim_group_mapping_external_id"),
        "scim_group_mapping",
        ["external_id"],
        unique=True,
    )
    op.create_table(
        "scim_user_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        # At most one SCIM mapping per user.
        sa.UniqueConstraint("user_id"),
    )
    op.create_index(
        op.f("ix_scim_user_mapping_external_id"),
        "scim_user_mapping",
        ["external_id"],
        unique=True,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the SCIM tables and their unique external-id indexes.

    Mapping tables (and their indexes) go first, then the token table —
    the reverse of creation order in upgrade().
    """
    for table in ("scim_user_mapping", "scim_group_mapping"):
        op.drop_index(op.f(f"ix_{table}_external_id"), table_name=table)
        op.drop_table(table)
    op.drop_table("scim_token")
|
||||
@@ -1,20 +1,20 @@
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -536,7 +536,9 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
callback = PermissionSyncCallback(
|
||||
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
|
||||
)
|
||||
|
||||
# pass in the capability to fetch all existing docs for the cc_pair
|
||||
# this is can be used to determine documents that are "missing" and thus
|
||||
@@ -576,6 +578,13 @@ def connector_permission_sync_generator_task(
|
||||
tasks_generated = 0
|
||||
docs_with_errors = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
f"Permission sync task timed out or stop signal detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
result = redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
@@ -932,6 +941,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
@@ -944,11 +954,26 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
self.last_tag: str = "PermissionSyncCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"PermissionSyncCallback - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
|
||||
@@ -466,6 +466,7 @@ def connector_external_group_sync_generator_task(
|
||||
def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
timeout_seconds: int = JOB_TIMEOUT,
|
||||
) -> None:
|
||||
# Create attempt record at the start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -518,9 +519,23 @@ def _perform_external_group_sync(
|
||||
seen_users: set[str] = set() # Track unique users across all groups
|
||||
total_groups_processed = 0
|
||||
total_group_memberships_synced = 0
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
raise RuntimeError(
|
||||
f"External group sync task timed out: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"elapsed={elapsed:.0f}s "
|
||||
f"timeout={timeout_seconds}s "
|
||||
f"groups_processed={total_groups_processed}"
|
||||
)
|
||||
|
||||
external_user_group_batch.append(external_user_group)
|
||||
|
||||
# Track progress
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -46,16 +43,11 @@ def sharepoint_group_sync(
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
# Process each site
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
# Create client context for the site using connector's MSAL app
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
ctx = connector._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
|
||||
@@ -27,6 +27,8 @@ class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
|
||||
@@ -67,6 +67,8 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
|
||||
0
backend/ee/onyx/server/scim/__init__.py
Normal file
0
backend/ee/onyx/server/scim/__init__.py
Normal file
96
backend/ee/onyx/server/scim/filtering.py
Normal file
96
backend/ee/onyx/server/scim/filtering.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
|
||||
|
||||
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
|
||||
resources before deciding whether to create or update them. For example, when
|
||||
an admin assigns a user to the Onyx app, the IdP first checks whether that
|
||||
user already exists::
|
||||
|
||||
GET /scim/v2/Users?filter=userName eq "john@example.com"
|
||||
|
||||
If zero results come back the IdP creates the user (``POST``); if a match is
|
||||
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
|
||||
The same pattern applies to groups (``displayName eq "Engineering"``).
|
||||
|
||||
This module parses the subset of the SCIM filter grammar that identity
|
||||
providers actually send in practice:
|
||||
|
||||
attribute SP operator SP value
|
||||
|
||||
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
|
||||
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser raises ``ValueError``, and the caller is expected to fall back
to an unfiltered list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScimFilterOperator(str, Enum):
    """Supported SCIM filter operators."""

    EQUAL = "eq"  # exact match
    CONTAINS = "co"  # substring match
    STARTS_WITH = "sw"  # prefix match
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ScimFilter:
    """Parsed SCIM filter expression."""

    # Attribute being filtered on, e.g. "userName" (case preserved as sent).
    attribute: str
    # Comparison operator (eq / co / sw), lowercased during parsing.
    operator: ScimFilterOperator
    # Comparison value with its surrounding quotes stripped.
    value: str
|
||||
|
||||
|
||||
# Matches: attribute operator "value" — the value MUST be quoted (double or
# single quotes); unquoted values are rejected as malformed.
# Groups: (attribute) (operator) (double-quoted value) (single-quoted value);
# exactly one of the two value groups participates in any match.
_FILTER_RE = re.compile(
    r"^(\S+)\s+(eq|co|sw)\s+"  # attribute + operator
    r'(?:"([^"]*)"'  # double-quoted value
    r"|'([^']*)')"  # or single-quoted value
    r"$",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
    """Parse a simple SCIM filter expression.

    Args:
        filter_string: Raw filter query parameter value, e.g.
            ``'userName eq "john@example.com"'``

    Returns:
        A ``ScimFilter`` if the expression is valid and uses a supported
        operator, or ``None`` if the input is empty / missing.

    Raises:
        ValueError: If the filter string is present but malformed or uses
            an unsupported operator.
    """
    stripped = filter_string.strip() if filter_string else ""
    if not stripped:
        # Absent or whitespace-only filter: caller lists all resources.
        return None

    match = _FILTER_RE.match(stripped)
    if match is None:
        raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")

    return _build_filter(match, filter_string)
|
||||
|
||||
|
||||
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
    """Turn a successful ``_FILTER_RE`` match into a ``ScimFilter``."""
    attr, op_token, dq_value, sq_value = match.group(1, 2, 3, 4)

    # The regex guarantees exactly one quoted-value group matched; the
    # None check below is purely defensive.
    value = dq_value if dq_value is not None else sq_value
    if value is None:
        raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")

    return ScimFilter(
        attribute=attr,
        operator=ScimFilterOperator(op_token.lower()),
        value=value,
    )
|
||||
255
backend/ee/onyx/server/scim/models.py
Normal file
255
backend/ee/onyx/server/scim/models.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
|
||||
|
||||
SCIM protocol schemas follow the wire format defined in:
|
||||
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
|
||||
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
|
||||
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Schema URIs (RFC 7643 §8)
|
||||
# Every SCIM JSON payload includes a "schemas" array identifying its type.
|
||||
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Protocol Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimName(BaseModel):
|
||||
"""User name components (RFC 7643 §4.1.1)."""
|
||||
|
||||
givenName: str | None = None
|
||||
familyName: str | None = None
|
||||
formatted: str | None = None
|
||||
|
||||
|
||||
class ScimEmail(BaseModel):
|
||||
"""Email sub-attribute (RFC 7643 §4.1.2)."""
|
||||
|
||||
value: str
|
||||
type: str | None = None
|
||||
primary: bool = False
|
||||
|
||||
|
||||
class ScimMeta(BaseModel):
|
||||
"""Resource metadata (RFC 7643 §3.1)."""
|
||||
|
||||
resourceType: str | None = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
    """SCIM User resource representation (RFC 7643 §4.1).

    This is the JSON shape that IdPs send when creating/updating a user via
    SCIM, and the shape we return in GET responses. Field names use camelCase
    to match the SCIM wire format (not Python convention).
    """

    schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
    id: str | None = None  # Onyx's internal user ID, set on responses
    externalId: str | None = None  # IdP's identifier for this user
    userName: str  # Typically the user's email address
    name: ScimName | None = None
    emails: list[ScimEmail] = Field(default_factory=list)
    active: bool = True  # IdPs set this to False (via PATCH) to deactivate
    meta: ScimMeta | None = None  # Server-populated metadata on responses
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
"""Group member reference (RFC 7643 §4.2).
|
||||
|
||||
Represents a user within a SCIM group. The IdP sends these when adding
|
||||
or removing users from groups. ``value`` is the Onyx user ID.
|
||||
"""
|
||||
|
||||
value: str # User ID of the group member
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimGroupResource(BaseModel):
|
||||
"""SCIM Group resource representation (RFC 7643 §4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
|
||||
id: str | None = None
|
||||
externalId: str | None = None
|
||||
displayName: str
|
||||
members: list[ScimGroupMember] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimListResponse(BaseModel):
|
||||
"""Paginated list response (RFC 7644 §3.4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
|
||||
totalResults: int
|
||||
startIndex: int = 1
|
||||
itemsPerPage: int = 100
|
||||
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimPatchOperationType(str, Enum):
|
||||
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
|
||||
|
||||
ADD = "add"
|
||||
REPLACE = "replace"
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
IdPs use PATCH to make incremental changes — e.g. deactivating a user
|
||||
(replace active=false) or adding/removing group members — instead of
|
||||
replacing the entire resource with PUT.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
|
||||
Operations: list[ScimPatchOperation]
|
||||
|
||||
|
||||
class ScimError(BaseModel):
|
||||
"""SCIM error response (RFC 7644 §3.12)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
|
||||
status: str
|
||||
detail: str | None = None
|
||||
scimType: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Provider Configuration (RFC 7643 §5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimSupported(BaseModel):
|
||||
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
|
||||
|
||||
supported: bool
|
||||
|
||||
|
||||
class ScimFilterConfig(BaseModel):
|
||||
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
|
||||
|
||||
supported: bool
|
||||
maxResults: int = 100
|
||||
|
||||
|
||||
class ScimServiceProviderConfig(BaseModel):
|
||||
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
|
||||
|
||||
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
|
||||
initial setup to discover which SCIM features our server supports
|
||||
(e.g. PATCH yes, bulk no, filtering yes).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(
|
||||
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
|
||||
)
|
||||
patch: ScimSupported = ScimSupported(supported=True)
|
||||
bulk: ScimSupported = ScimSupported(supported=False)
|
||||
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
|
||||
changePassword: ScimSupported = ScimSupported(supported=False)
|
||||
sort: ScimSupported = ScimSupported(supported=False)
|
||||
etag: ScimSupported = ScimSupported(supported=False)
|
||||
authenticationSchemes: list[dict[str, str]] = Field(
|
||||
default_factory=lambda: [
|
||||
{
|
||||
"type": "oauthbearertoken",
|
||||
"name": "OAuth Bearer Token",
|
||||
"description": "Authentication scheme using a SCIM bearer token",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
|
||||
|
||||
class ScimResourceType(BaseModel):
|
||||
"""SCIM ResourceType resource (RFC 7643 §6).
|
||||
|
||||
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
endpoint: str
|
||||
description: str | None = None
|
||||
schema_: str = Field(alias="schema")
|
||||
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin API Schemas (Onyx-internal, for SCIM token management)
|
||||
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
|
||||
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimTokenCreate(BaseModel):
|
||||
"""Request to create a new SCIM bearer token."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ScimTokenResponse(BaseModel):
|
||||
"""SCIM token metadata returned in list/get responses."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
"""Response returned when a new SCIM token is created.
|
||||
|
||||
Includes the raw token value which is only available at creation time.
|
||||
"""
|
||||
|
||||
raw_token: str
|
||||
256
backend/ee/onyx/server/scim/patch.py
Normal file
256
backend/ee/onyx/server/scim/patch.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
|
||||
|
||||
Identity providers use PATCH to make incremental changes to SCIM resources
|
||||
instead of replacing the entire resource with PUT. Common operations include:
|
||||
|
||||
- Deactivating a user: ``replace`` ``active`` with ``false``
|
||||
- Adding group members: ``add`` to ``members``
|
||||
- Removing group members: ``remove`` from ``members[value eq "..."]``
|
||||
|
||||
This module applies PATCH operations to Pydantic SCIM resource objects and
|
||||
returns the modified result. It does NOT touch the database — the caller is
|
||||
responsible for persisting changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
    """Raised when a PATCH operation cannot be applied."""

    def __init__(self, detail: str, status: int = 400) -> None:
        # Human-readable message for the SCIM error response body.
        self.detail = detail
        # HTTP status the endpoint should return (defaults to 400 Bad Request).
        self.status = status
        super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for the member-removal path: members[value eq "user-id"]
# group(1) captures the quoted user id to remove.
_MEMBER_FILTER_RE = re.compile(
    r'^members\[value\s+eq\s+"([^"]+)"\]$',
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def apply_user_patch(
    operations: list[ScimPatchOperation],
    current: ScimUserResource,
) -> ScimUserResource:
    """Apply SCIM PATCH operations to a user resource.

    Returns a new ``ScimUserResource`` with the modifications applied.
    The original object is not mutated.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    patched = current.model_dump()
    name_fields = patched.get("name") or {}

    # "add" and "replace" behave identically for the flat user attributes
    # we support; every other op (e.g. "remove") is rejected.
    supported_ops = (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD)

    for operation in operations:
        if operation.op not in supported_ops:
            raise ScimPatchError(
                f"Unsupported operation '{operation.op.value}' on User resource"
            )
        _apply_user_replace(operation, patched, name_fields)

    patched["name"] = name_fields
    return ScimUserResource.model_validate(patched)
|
||||
|
||||
|
||||
def _apply_user_replace(
    op: ScimPatchOperation,
    data: dict,
    name_data: dict,
) -> None:
    """Apply a replace/add operation to user data."""
    path = (op.path or "").lower()

    # With an explicit path, set that one field and we're done.
    if path:
        _set_user_field(path, op.value, data, name_data)
        return

    # A path-less replace carries a dict of top-level attributes to set.
    if not isinstance(op.value, dict):
        raise ScimPatchError("Replace without path requires a dict value")
    for attr, attr_value in op.value.items():
        _set_user_field(attr.lower(), attr_value, data, name_data)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def apply_group_patch(
    operations: list[ScimPatchOperation],
    current: ScimGroupResource,
) -> tuple[ScimGroupResource, list[str], list[str]]:
    """Apply SCIM PATCH operations to a group resource.

    Returns:
        A tuple of (modified group, added member IDs, removed member IDs).
        The caller uses the member ID lists to update the database.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    data = current.model_dump()
    members: list[dict] = list(data.get("members") or [])
    added: list[str] = []
    removed: list[str] = []

    # Dispatch table: one handler per supported PATCH op.
    handlers = {
        ScimPatchOperationType.REPLACE: lambda op: _apply_group_replace(
            op, data, members, added, removed
        ),
        ScimPatchOperationType.ADD: lambda op: _apply_group_add(op, members, added),
        ScimPatchOperationType.REMOVE: lambda op: _apply_group_remove(
            op, members, removed
        ),
    }

    for operation in operations:
        handler = handlers.get(operation.op)
        if handler is None:  # defensive: enum currently covers all ops
            raise ScimPatchError(
                f"Unsupported operation '{operation.op.value}' on Group resource"
            )
        handler(operation)

    data["members"] = members
    return ScimGroupResource.model_validate(data), added, removed
|
||||
|
||||
|
||||
def _apply_group_replace(
    op: ScimPatchOperation,
    data: dict,
    current_members: list[dict],
    added_ids: list[str],
    removed_ids: list[str],
) -> None:
    """Apply a replace operation to group data."""

    def _set_one(attr: str, attr_value) -> None:
        # "members" gets a diff-tracked replacement; everything else is a
        # plain scalar field on the group.
        if attr == "members":
            _replace_members(attr_value, current_members, added_ids, removed_ids)
        else:
            _set_group_field(attr, attr_value, data)

    path = (op.path or "").lower()
    if path:
        _set_one(path, op.value)
        return

    # Path-less replace carries a dict of attributes to set.
    if not isinstance(op.value, dict):
        raise ScimPatchError("Replace without path requires a dict value")
    for key, val in op.value.items():
        _set_one(key.lower(), val)
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
removed_ids.extend(old_ids - new_ids)
|
||||
added_ids.extend(new_ids - old_ids)
|
||||
|
||||
current_members[:] = value
|
||||
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
added_ids: list[str],
|
||||
) -> None:
|
||||
"""Add members to a group."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if path and path != "members":
|
||||
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
|
||||
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
added_ids.append(member_id)
|
||||
existing_ids.add(member_id)
|
||||
|
||||
|
||||
def _apply_group_remove(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove members from a group."""
|
||||
if not op.path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
match = _MEMBER_FILTER_RE.match(op.path)
|
||||
if not match:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported remove path '{op.path}'. "
|
||||
'Expected: members[value eq "user-id"]'
|
||||
)
|
||||
|
||||
target_id = match.group(1)
|
||||
original_len = len(members)
|
||||
members[:] = [m for m in members if m.get("value") != target_id]
|
||||
|
||||
if len(members) < original_len:
|
||||
removed_ids.append(target_id)
|
||||
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
@@ -41,8 +43,21 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
captcha_token: str | None = None
|
||||
|
||||
@override
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
@@ -37,6 +37,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
@@ -51,11 +52,29 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
# Check if the associated indexing attempt has been cancelled
|
||||
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
|
||||
return bool(self.redis_connector.stop.fenced)
|
||||
if bool(self.redis_connector.stop.fenced):
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"IndexingCallback Docprocessing - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
"""Amount isn't used yet."""
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Celery tasks for hierarchy fetching."""
|
||||
|
||||
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
|
||||
check_for_hierarchy_fetching,
|
||||
)
|
||||
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
|
||||
connector_hierarchy_fetching_task,
|
||||
)
|
||||
|
||||
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]
|
||||
@@ -146,14 +146,26 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": "celery",
|
||||
"docprocessing_queue_length": "docprocessing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
|
||||
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
|
||||
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
|
||||
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
|
||||
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
|
||||
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
|
||||
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
|
||||
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
|
||||
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
|
||||
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
|
||||
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
|
||||
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
@@ -881,7 +893,7 @@ def monitor_celery_queues_helper(
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
@@ -908,6 +920,26 @@ def monitor_celery_queues_helper(
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
n_hierarchy_fetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
|
||||
)
|
||||
n_llm_model_update = celery_get_queue_length(
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
|
||||
)
|
||||
n_checkpoint_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
|
||||
)
|
||||
n_index_attempt_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
|
||||
)
|
||||
n_csv_generation = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CSV_GENERATION, r_celery
|
||||
)
|
||||
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
|
||||
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
|
||||
n_opensearch_migration = celery_get_queue_length(
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
|
||||
)
|
||||
|
||||
n_docfetching_prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -931,6 +963,14 @@ def monitor_celery_queues_helper(
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
f"hierarchy_fetching={n_hierarchy_fetching} "
|
||||
f"llm_model_update={n_llm_model_update} "
|
||||
f"checkpoint_cleanup={n_checkpoint_cleanup} "
|
||||
f"index_attempt_cleanup={n_index_attempt_cleanup} "
|
||||
f"csv_generation={n_csv_generation} "
|
||||
f"monitoring={n_monitoring} "
|
||||
f"sandbox={n_sandbox} "
|
||||
f"opensearch_migration={n_opensearch_migration} "
|
||||
)
|
||||
|
||||
|
||||
|
||||
8
backend/onyx/background/celery/tasks/pruning/__init__.py
Normal file
8
backend/onyx/background/celery/tasks/pruning/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Celery tasks for connector pruning."""
|
||||
|
||||
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
|
||||
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
|
||||
connector_pruning_generator_task,
|
||||
)
|
||||
|
||||
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]
|
||||
@@ -523,6 +523,7 @@ def connector_pruning_generator_task(
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# a list of docs in the source
|
||||
|
||||
@@ -3,34 +3,26 @@ from collections.abc import Callable
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import check_project_ownership
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -47,9 +39,6 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -278,70 +267,6 @@ def extract_headers(
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
# Use the first prompt from the override config for embedded prompt fields
|
||||
first_prompt = persona_config.prompts[0]
|
||||
persona.system_prompt = first_prompt.system_prompt
|
||||
persona.task_prompt = first_prompt.task_prompt
|
||||
persona.datetime_aware = first_prompt.datetime_aware
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=schema,
|
||||
emitter=get_default_emitter(),
|
||||
),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def process_kg_commands(
|
||||
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
|
||||
) -> None:
|
||||
@@ -688,28 +613,34 @@ def convert_chat_history(
|
||||
|
||||
|
||||
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
|
||||
"""Get the custom agent prompt from persona or project instructions.
|
||||
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
|
||||
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
|
||||
|
||||
Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
|
||||
Priority: persona.system_prompt > chat_session.project.instructions > None
|
||||
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions
|
||||
|
||||
# NOTE: Logic elsewhere allows saving empty strings for potentially other purposes but for constructing the prompts
|
||||
# we never want to return an empty string for a prompt so it's translated into an explicit None.
|
||||
|
||||
Args:
|
||||
persona: The Persona object
|
||||
chat_session: The ChatSession object
|
||||
|
||||
Returns:
|
||||
The custom agent prompt string, or None if neither persona nor project has one
|
||||
The prompt to use for the custom Agent part of the prompt.
|
||||
"""
|
||||
# Not considered a custom agent if it's the default behavior persona
|
||||
if persona.id == DEFAULT_PERSONA_ID:
|
||||
return None
|
||||
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
# Logic exists later also to drop it in this case but this is strictly correct anyhow.
|
||||
if persona.replace_base_system_prompt:
|
||||
return None
|
||||
return persona.system_prompt or None
|
||||
|
||||
if persona.system_prompt:
|
||||
return persona.system_prompt
|
||||
elif chat_session.project and chat_session.project.instructions:
|
||||
# If in a project and using the default Agent, respect the project instructions.
|
||||
if chat_session.project and chat_session.project.instructions:
|
||||
return chat_session.project.instructions
|
||||
else:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:
|
||||
|
||||
@@ -38,7 +38,6 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
|
||||
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -594,6 +593,7 @@ def run_llm_loop(
|
||||
|
||||
reasoning_cycles = 0
|
||||
for llm_cycle_count in range(MAX_LLM_CYCLES):
|
||||
# Handling tool calls based on cycle count and past cycle conditions
|
||||
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
|
||||
if forced_tool_id:
|
||||
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
|
||||
@@ -615,6 +615,7 @@ def run_llm_loop(
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
|
||||
# Handling the system prompt and custom agent prompt
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
@@ -632,12 +633,14 @@ def run_llm_loop(
|
||||
else:
|
||||
# If it's an empty string, we assume the user does not want to include it as an empty System message
|
||||
if default_base_system_prompt:
|
||||
open_ai_formatting_enabled = model_needs_formatting_reenabled(
|
||||
llm.config.model_name
|
||||
)
|
||||
|
||||
prompt_memory_context = (
|
||||
user_memory_context if inject_memories_in_prompt else None
|
||||
user_memory_context
|
||||
if inject_memories_in_prompt
|
||||
else (
|
||||
user_memory_context.without_memories()
|
||||
if user_memory_context
|
||||
else None
|
||||
)
|
||||
)
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
@@ -646,7 +649,6 @@ def run_llm_loop(
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
open_ai_formatting_enabled=open_ai_formatting_enabled,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=system_prompt_str,
|
||||
|
||||
@@ -36,6 +36,8 @@ from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -623,6 +625,17 @@ def translate_history_to_llm_format(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
# Apply model-specific formatting when translating to LLM format (e.g. OpenAI
|
||||
# reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation)
|
||||
if model_needs_formatting_reenabled(llm_config.model_name):
|
||||
for i, m in enumerate(messages):
|
||||
if isinstance(m, SystemMessage):
|
||||
messages[i] = SystemMessage(
|
||||
role="system",
|
||||
content=CODE_BLOCK_MARKDOWN + m.content,
|
||||
)
|
||||
break
|
||||
|
||||
# prompt caching: rely on should_cache in ChatMessageSimple to
|
||||
# pick the split point for the cacheable prefix and suffix
|
||||
if last_cacheable_msg_idx != -1:
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -20,54 +16,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
class StreamType(Enum):
|
||||
SUB_QUESTIONS = "sub_questions"
|
||||
SUB_ANSWER = "sub_answer"
|
||||
MAIN_ANSWER = "main_answer"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
stream_type: StreamType = StreamType.MAIN_ANSWER
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class OnyxAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
@@ -78,23 +26,11 @@ class StreamingError(BaseModel):
|
||||
details: dict | None = None # Additional context (tool name, model name, etc.)
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
@@ -102,83 +38,15 @@ class ProjectSearchConfig(BaseModel):
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
datetime_aware: bool = True
|
||||
include_citations: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
# Note: prompt_ids removed - prompts are now embedded in personas
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| StreamStopInfo
|
||||
| MessageResponseIDInfo
|
||||
| StreamingError
|
||||
| UserKnowledgeFilePacket
|
||||
| CreateChatSessionID
|
||||
)
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
|
||||
top_documents: list[SearchDoc]
|
||||
|
||||
error_msg: str | None
|
||||
message_id: int
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Tool call with full details for non-streaming response."""
|
||||
|
||||
@@ -191,8 +59,23 @@ class ToolCallResponse(BaseModel):
|
||||
pre_reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
|
||||
top_documents: list[SearchDoc]
|
||||
|
||||
error_msg: str | None
|
||||
message_id: int
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ChatFullResponse(BaseModel):
|
||||
"""Complete non-streaming response with all available data."""
|
||||
"""Complete non-streaming response with all available data.
|
||||
NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
"""
|
||||
|
||||
# Core response fields
|
||||
answer: str
|
||||
|
||||
@@ -37,7 +37,6 @@ from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
@@ -81,8 +80,7 @@ from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -615,16 +613,27 @@ def handle_stream_message_objects(
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
|
||||
# This is the custom prompt which may come from the Agent or Project. We fetch it earlier because the inner loop
|
||||
# (run_llm_loop and run_deep_research_llm_loop) should not need to be aware of the Chat History in the DB form processed
|
||||
# here, however we need this early for token reservation.
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# When use_memories is disabled, don't inject memories into the prompt
|
||||
# or count them in token reservation, but still pass the full context
|
||||
# When use_memories is disabled, strip memories from the prompt context
|
||||
# but keep user info/preferences. The full context is still passed
|
||||
# to the LLM loop for memory tool persistence.
|
||||
prompt_memory_context = user_memory_context if user.use_memories else None
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if user.use_memories
|
||||
else user_memory_context.without_memories()
|
||||
)
|
||||
|
||||
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
|
||||
custom_agent_prompt or ""
|
||||
)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
persona_system_prompt=max_reserved_system_prompt_tokens_str,
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=prompt_memory_context,
|
||||
@@ -1016,68 +1025,6 @@ def llm_loop_completion_handle(
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages. Both of the below are used for Slack
|
||||
# NOTE: is not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
# Slack context for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> AnswerStream:
|
||||
forced_tool_id = (
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
)
|
||||
if (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
):
|
||||
all_tools = get_tools(db_session)
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
forced_tool_id = search_tool_id
|
||||
|
||||
translated_new_msg_req = SendMessageRequest(
|
||||
message=new_msg_req.message,
|
||||
llm_override=new_msg_req.llm_override,
|
||||
mock_llm_response=new_msg_req.mock_llm_response,
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
forced_tool_id=forced_tool_id,
|
||||
file_descriptors=new_msg_req.file_descriptors,
|
||||
internal_search_filters=(
|
||||
new_msg_req.retrieval_options.filters
|
||||
if new_msg_req.retrieval_options
|
||||
else None
|
||||
),
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
bypass_acl=bypass_acl,
|
||||
additional_context=additional_context,
|
||||
slack_context=slack_context,
|
||||
)
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ from onyx.db.persona import get_default_behavior_persona
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.prompt_utils import replace_reminder_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
@@ -25,7 +25,12 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_INFORMATION_HEADER
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -131,6 +136,59 @@ def build_reminder_message(
|
||||
return reminder if reminder else None
|
||||
|
||||
|
||||
def _build_user_information_section(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
company_context: str | None,
|
||||
) -> str:
|
||||
"""Build the complete '# User Information' section with all sub-sections
|
||||
in the correct order: Basic Info → Team Info → Preferences → Memories."""
|
||||
sections: list[str] = []
|
||||
|
||||
if user_memory_context:
|
||||
ctx = user_memory_context
|
||||
has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip()
|
||||
if ctx.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=ctx.user_info.name or "",
|
||||
user_email=ctx.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if company_context:
|
||||
sections.append(
|
||||
TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip())
|
||||
)
|
||||
|
||||
if user_memory_context:
|
||||
ctx = user_memory_context
|
||||
|
||||
if ctx.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences)
|
||||
)
|
||||
|
||||
if ctx.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return USER_INFORMATION_HEADER + "".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
base_system_prompt: str,
|
||||
datetime_aware: bool = False,
|
||||
@@ -138,18 +196,12 @@ def build_system_prompt(
|
||||
tools: Sequence[Tool] | None = None,
|
||||
should_cite_documents: bool = False,
|
||||
include_all_guidance: bool = False,
|
||||
open_ai_formatting_enabled: bool = False,
|
||||
) -> str:
|
||||
"""Should only be called with the default behavior system prompt.
|
||||
If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function.
|
||||
"""
|
||||
system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware)
|
||||
|
||||
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
|
||||
# for OpenAI reasoning models to have correct markdown generation
|
||||
if open_ai_formatting_enabled:
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
|
||||
# Replace citation guidance placeholder if present
|
||||
system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
|
||||
system_prompt,
|
||||
@@ -157,16 +209,14 @@ def build_system_prompt(
|
||||
include_all_guidance=include_all_guidance,
|
||||
)
|
||||
|
||||
# Replace reminder tag placeholder if present
|
||||
system_prompt = replace_reminder_tag(system_prompt)
|
||||
|
||||
company_context = get_company_context()
|
||||
formatted_user_context = (
|
||||
user_memory_context.as_formatted_prompt() if user_memory_context else ""
|
||||
user_info_section = _build_user_information_section(
|
||||
user_memory_context, company_context
|
||||
)
|
||||
if company_context or formatted_user_context:
|
||||
system_prompt += USER_INFORMATION_HEADER
|
||||
if company_context:
|
||||
system_prompt += company_context
|
||||
if formatted_user_context:
|
||||
system_prompt += formatted_user_context
|
||||
system_prompt += user_info_section
|
||||
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
|
||||
@@ -977,6 +977,7 @@ API_KEY_HASH_ROUNDS = (
|
||||
# MCP Server Configs
|
||||
#####
|
||||
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
|
||||
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
|
||||
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)
|
||||
|
||||
# CORS origins for MCP clients (comma-separated)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextvars
|
||||
import re
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -14,6 +15,7 @@ from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
@@ -62,11 +64,44 @@ class AirtableClientNotSetUpError(PermissionError):
|
||||
super().__init__("Airtable Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
# Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
|
||||
# Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ)
|
||||
_AIRTABLE_URL_PATTERN = re.compile(
|
||||
r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
|
||||
)
|
||||
|
||||
|
||||
def parse_airtable_url(
|
||||
url: str,
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Parse an Airtable URL into (base_id, table_id, view_id).
|
||||
|
||||
Accepts URLs like:
|
||||
https://airtable.com/appXXX/tblYYY
|
||||
https://airtable.com/appXXX/tblYYY/viwZZZ
|
||||
https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
|
||||
|
||||
Returns:
|
||||
(base_id, table_id, view_id or None)
|
||||
|
||||
Raises:
|
||||
ValueError if the URL doesn't match the expected format.
|
||||
"""
|
||||
match = _AIRTABLE_URL_PATTERN.search(url.strip())
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Could not parse Airtable URL: '{url}'. "
|
||||
"Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
|
||||
)
|
||||
return match.group(1), match.group(2), match.group(3)
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
base_id: str = "",
|
||||
table_name_or_id: str = "",
|
||||
airtable_url: str = "",
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
view_id: str | None = None,
|
||||
share_id: str | None = None,
|
||||
@@ -75,16 +110,33 @@ class AirtableConnector(LoadConnector):
|
||||
"""Initialize an AirtableConnector.
|
||||
|
||||
Args:
|
||||
base_id: The ID of the Airtable base to connect to
|
||||
table_name_or_id: The name or ID of the table to index
|
||||
base_id: The ID of the Airtable base (not required when airtable_url is set)
|
||||
table_name_or_id: The name or ID of the table (not required when airtable_url is set)
|
||||
airtable_url: An Airtable URL to parse base_id, table_id, and view_id from.
|
||||
Overrides base_id, table_name_or_id, and view_id if provided.
|
||||
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
|
||||
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
|
||||
view_id: Optional ID of a specific view to use
|
||||
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
|
||||
share_id: Optional ID of a "share" to use for generating record URLs
|
||||
batch_size: Number of records to process in each batch
|
||||
|
||||
Mode is auto-detected: if a specific table is identified (via URL or
|
||||
base_id + table_name_or_id), the connector indexes that single table.
|
||||
Otherwise, it discovers and indexes all accessible bases and tables.
|
||||
"""
|
||||
# If a URL is provided, parse it to extract base_id, table_id, and view_id
|
||||
if airtable_url:
|
||||
parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url(
|
||||
airtable_url
|
||||
)
|
||||
base_id = parsed_base_id
|
||||
table_name_or_id = parsed_table_id
|
||||
if parsed_view_id:
|
||||
view_id = parsed_view_id
|
||||
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.index_all = not (base_id and table_name_or_id)
|
||||
self.view_id = view_id
|
||||
self.share_id = share_id
|
||||
self.batch_size = batch_size
|
||||
@@ -103,6 +155,33 @@ class AirtableConnector(LoadConnector):
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.index_all:
|
||||
try:
|
||||
bases = self.airtable_client.bases()
|
||||
if not bases:
|
||||
raise ConnectorValidationError(
|
||||
"No bases found. Ensure your API token has access to at least one base."
|
||||
)
|
||||
except ConnectorValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Failed to list Airtable bases: {e}")
|
||||
else:
|
||||
if not self.base_id or not self.table_name_or_id:
|
||||
raise ConnectorValidationError(
|
||||
"A valid Airtable URL or base_id and table_name_or_id are required "
|
||||
"when not using index_all mode."
|
||||
)
|
||||
try:
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table.schema()
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to access table '{self.table_name_or_id}' "
|
||||
f"in base '{self.base_id}': {e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_record_url(
|
||||
cls,
|
||||
@@ -267,6 +346,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
@@ -291,7 +371,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=self.base_id,
|
||||
base_id=base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -326,15 +406,17 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
base_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> Document | None:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
base_id: The ID of the base this record belongs to
|
||||
base_name: The name of the base (used in semantic ID for index_all mode)
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
@@ -367,6 +449,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
base_id=base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -379,11 +462,26 @@ class AirtableConnector(LoadConnector):
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
# Include base name in semantic ID only in index_all mode
|
||||
if self.index_all and base_name:
|
||||
semantic_id = (
|
||||
f"{base_name} > {table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else f"{base_name} > {table_name}"
|
||||
)
|
||||
else:
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
# Build hierarchy source_path for Craft file system subdirectory structure.
|
||||
# This creates: airtable/{base_name}/{table_name}/record.json
|
||||
source_path: list[str] = []
|
||||
if base_name:
|
||||
source_path.append(base_name)
|
||||
source_path.append(table_name)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
@@ -391,19 +489,39 @@ class AirtableConnector(LoadConnector):
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": source_path,
|
||||
"base_id": base_id,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
**({"base_name": base_name} if base_name else {}),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
def _resolve_base_name(self, base_id: str) -> str | None:
|
||||
"""Try to resolve a human-readable base name from the API."""
|
||||
try:
|
||||
for base_info in self.airtable_client.bases():
|
||||
if base_info.id == base_id:
|
||||
return base_info.name
|
||||
except Exception:
|
||||
logger.debug(f"Could not resolve base name for {base_id}")
|
||||
return None
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
def _index_table(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Index all records from a single table. Yields batches of Documents."""
|
||||
# Resolve base name for hierarchy if not provided
|
||||
if base_name is None:
|
||||
base_name = self._resolve_base_name(base_id)
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table = self.airtable_client.table(base_id, table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
@@ -415,21 +533,25 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
logger.info(
|
||||
f"Processing {len(records)} records from table "
|
||||
f"'{table_schema.name}' in base '{base_name or base_id}'."
|
||||
)
|
||||
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents = []
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record: dict[Future, RecordDict] = {}
|
||||
future_to_record: dict[Future[Document | None], RecordDict] = {}
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@@ -440,6 +562,8 @@ class AirtableConnector(LoadConnector):
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
base_id=base_id,
|
||||
base_name=base_name,
|
||||
)
|
||||
] = record
|
||||
|
||||
@@ -454,9 +578,58 @@ class AirtableConnector(LoadConnector):
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from one or all tables.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
if self.index_all:
|
||||
yield from self._load_all()
|
||||
else:
|
||||
yield from self._index_table(
|
||||
base_id=self.base_id,
|
||||
table_name_or_id=self.table_name_or_id,
|
||||
)
|
||||
|
||||
def _load_all(self) -> GenerateDocumentsOutput:
|
||||
"""Discover all bases and tables, then index everything."""
|
||||
bases = self.airtable_client.bases()
|
||||
logger.info(f"Discovered {len(bases)} Airtable base(s).")
|
||||
|
||||
for base_info in bases:
|
||||
base_id = base_info.id
|
||||
base_name = base_info.name
|
||||
logger.info(f"Listing tables for base '{base_name}' ({base_id}).")
|
||||
|
||||
try:
|
||||
base = self.airtable_client.base(base_id)
|
||||
tables = base.tables()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
yield from self._index_table(
|
||||
base_id=base_id,
|
||||
table_name_or_id=table.id,
|
||||
base_name=base_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to index table '{table.name}' ({table.id}) "
|
||||
f"in base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -79,6 +79,13 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
# The office365 library's ClientContext caches the access token from
|
||||
# The office365 library's ClientContext caches the access token from its
|
||||
# first request and never re-invokes the token callback. Microsoft access
|
||||
# tokens live ~60-75 minutes, so we recreate the cached ClientContext every
|
||||
# 30 minutes to let MSAL transparently handle token refresh.
|
||||
_REST_CTX_MAX_AGE_S = 30 * 60
|
||||
|
||||
|
||||
class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
@@ -104,30 +111,11 @@ class CertificateData(BaseModel):
|
||||
thumbprint: str
|
||||
|
||||
|
||||
# TODO(Evan): Remove this once we have a proper token refresh mechanism.
|
||||
def _clear_cached_token(query_obj: ClientQuery) -> bool:
|
||||
"""Clear the cached access token on the query object's ClientContext so
|
||||
the next request re-invokes the token callback and gets a fresh token.
|
||||
|
||||
The office365 library's AuthenticationContext.with_access_token() caches
|
||||
the token in ``_cached_token`` and never refreshes it. Setting it to
|
||||
``None`` forces re-acquisition on the next request.
|
||||
|
||||
Returns True if the token was successfully cleared."""
|
||||
ctx = getattr(query_obj, "context", query_obj)
|
||||
auth_ctx = getattr(ctx, "authentication_context", None)
|
||||
if auth_ctx is not None and hasattr(auth_ctx, "_cached_token"):
|
||||
auth_ctx._cached_token = None
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def sleep_and_retry(
|
||||
query_obj: ClientQuery, method_name: str, max_retries: int = 3
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a SharePoint query with retry logic for rate limiting
|
||||
and automatic token refresh on 401 Unauthorized.
|
||||
Execute a SharePoint query with retry logic for rate limiting.
|
||||
"""
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
@@ -135,15 +123,6 @@ def sleep_and_retry(
|
||||
except ClientRequestException as e:
|
||||
status = e.response.status_code if e.response is not None else None
|
||||
|
||||
# 401 — token expired. Clear the cached token and retry immediately.
|
||||
if status == 401 and attempt < max_retries:
|
||||
cleared = _clear_cached_token(query_obj)
|
||||
logger.warning(
|
||||
f"Token expired on {method_name}, attempt {attempt + 1}/{max_retries + 1}, "
|
||||
f"cleared cached token={cleared}, retrying"
|
||||
)
|
||||
continue
|
||||
|
||||
# 429 / 503 — rate limit or transient error. Back off and retry.
|
||||
if status in (429, 503) and attempt < max_retries:
|
||||
logger.warning(
|
||||
@@ -742,6 +721,10 @@ class SharepointConnector(
|
||||
self.include_site_pages = include_site_pages
|
||||
self.include_site_documents = include_site_documents
|
||||
self.sp_tenant_domain: str | None = None
|
||||
self._credential_json: dict[str, Any] | None = None
|
||||
self._cached_rest_ctx: ClientContext | None = None
|
||||
self._cached_rest_ctx_url: str | None = None
|
||||
self._cached_rest_ctx_created_at: float = 0.0
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Validate that at least one content type is enabled
|
||||
@@ -767,6 +750,44 @@ class SharepointConnector(
|
||||
|
||||
return self._graph_client
|
||||
|
||||
def _create_rest_client_context(self, site_url: str) -> ClientContext:
|
||||
"""Return a ClientContext for SharePoint REST API calls, with caching.
|
||||
|
||||
The office365 library's ClientContext caches the access token from its
|
||||
first request and never re-invokes the token callback. We cache the
|
||||
context and recreate it when the site URL changes or after
|
||||
``_REST_CTX_MAX_AGE_S``. On recreation we also call
|
||||
``load_credentials`` to build a fresh MSAL app with an empty token
|
||||
cache, guaranteeing a brand-new token from Azure AD."""
|
||||
elapsed = time.monotonic() - self._cached_rest_ctx_created_at
|
||||
if (
|
||||
self._cached_rest_ctx is not None
|
||||
and self._cached_rest_ctx_url == site_url
|
||||
and elapsed <= _REST_CTX_MAX_AGE_S
|
||||
):
|
||||
return self._cached_rest_ctx
|
||||
|
||||
if self._credential_json:
|
||||
logger.info(
|
||||
"Rebuilding SharePoint REST client context "
|
||||
"(elapsed=%.0fs, site_changed=%s)",
|
||||
elapsed,
|
||||
self._cached_rest_ctx_url != site_url,
|
||||
)
|
||||
self.load_credentials(self._credential_json)
|
||||
|
||||
if not self.msal_app or not self.sp_tenant_domain:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
self._cached_rest_ctx_url = site_url
|
||||
self._cached_rest_ctx_created_at = time.monotonic()
|
||||
return self._cached_rest_ctx
|
||||
|
||||
@staticmethod
|
||||
def _strip_share_link_tokens(path: str) -> list[str]:
|
||||
# Share links often include a token prefix like /:f:/r/ or /:x:/r/.
|
||||
@@ -1206,21 +1227,6 @@ class SharepointConnector(
|
||||
# goes over all urls, converts them into SlimDocument objects and then yields them in batches
|
||||
doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for site_descriptor in site_descriptors:
|
||||
ctx: ClientContext | None = None
|
||||
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
if ctx is None:
|
||||
logger.warning("ClientContext is not set, skipping permissions")
|
||||
continue
|
||||
|
||||
site_url = site_descriptor.url
|
||||
|
||||
# Yield site hierarchy node using helper
|
||||
@@ -1259,6 +1265,7 @@ class SharepointConnector(
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_driveitem_to_slim_document(
|
||||
driveitem, drive_name, ctx, self.graph_client
|
||||
@@ -1278,6 +1285,7 @@ class SharepointConnector(
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page, ctx, self.graph_client
|
||||
@@ -1289,6 +1297,7 @@ class SharepointConnector(
|
||||
yield doc_batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._credential_json = credentials
|
||||
auth_method = credentials.get(
|
||||
"authentication_method", SharepointAuthMethod.CLIENT_SECRET.value
|
||||
)
|
||||
@@ -1705,17 +1714,6 @@ class SharepointConnector(
|
||||
)
|
||||
logger.debug(f"Time range: {start_dt} to {end_dt}")
|
||||
|
||||
ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
# At this point current_drive_name should be set from popleft()
|
||||
current_drive_name = checkpoint.current_drive_name
|
||||
if current_drive_name is None:
|
||||
@@ -1810,6 +1808,10 @@ class SharepointConnector(
|
||||
)
|
||||
|
||||
try:
|
||||
ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
doc = _convert_driveitem_to_document_with_permissions(
|
||||
driveitem,
|
||||
current_drive_name,
|
||||
@@ -1875,20 +1877,13 @@ class SharepointConnector(
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start_dt, end=end_dt
|
||||
)
|
||||
client_ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
client_ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
client_ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
client_ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
yield (
|
||||
_convert_sitepage_to_document(
|
||||
site_page,
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
@@ -97,21 +96,6 @@ class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class ChunkContext(BaseModel):
|
||||
# If not specified (None), picked up from Persona settings if there is space
|
||||
# if specified (even if 0), it always uses the specified number of chunks above and below
|
||||
chunks_above: int | None = None
|
||||
chunks_below: int | None = None
|
||||
full_doc: bool = False
|
||||
|
||||
@field_validator("chunks_above", "chunks_below")
|
||||
@classmethod
|
||||
def check_non_negative(cls, value: int, field: Any) -> int:
|
||||
if value is not None and value < 0:
|
||||
raise ValueError(f"{field.name} must be non-negative")
|
||||
return value
|
||||
|
||||
|
||||
class BasicChunkRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
@@ -672,27 +671,6 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_search_docs_table_with_relevance(
|
||||
db_session: Session,
|
||||
reference_db_search_docs: list[DBSearchDoc],
|
||||
relevance_summary: DocumentRelevance,
|
||||
) -> None:
|
||||
for search_doc in reference_db_search_docs:
|
||||
relevance_data = relevance_summary.relevance_summaries.get(
|
||||
search_doc.document_id
|
||||
)
|
||||
if relevance_data is not None:
|
||||
db_session.execute(
|
||||
update(DBSearchDoc)
|
||||
.where(DBSearchDoc.id == search_doc.id)
|
||||
.values(
|
||||
is_relevant=relevance_data.relevant,
|
||||
relevance_explanation=relevance_data.content,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
|
||||
@@ -296,4 +296,4 @@ class HierarchyNodeType(str, PyEnum):
|
||||
class LLMModelFlowType(str, PyEnum):
|
||||
CHAT = "chat"
|
||||
VISION = "vision"
|
||||
EMBEDDINGS = "embeddings"
|
||||
CONTEXTUAL_RAG = "contextual_rag"
|
||||
|
||||
@@ -509,6 +509,12 @@ def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None
|
||||
return fetch_default_model(db_session, LLMModelFlowType.VISION)
|
||||
|
||||
|
||||
def fetch_default_contextual_rag_model(
|
||||
db_session: Session,
|
||||
) -> ModelConfiguration | None:
|
||||
return fetch_default_model(db_session, LLMModelFlowType.CONTEXTUAL_RAG)
|
||||
|
||||
|
||||
def fetch_default_model(
|
||||
db_session: Session,
|
||||
flow_type: LLMModelFlowType,
|
||||
@@ -646,6 +652,73 @@ def update_default_vision_provider(
|
||||
)
|
||||
|
||||
|
||||
def update_no_default_contextual_rag_provider(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
update(LLMModelFlow)
|
||||
.where(
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
.values(is_default=False)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_default_contextual_model(
|
||||
db_session: Session,
|
||||
enable_contextual_rag: bool,
|
||||
contextual_rag_llm_provider: str | None,
|
||||
contextual_rag_llm_name: str | None,
|
||||
) -> None:
|
||||
"""Sets or clears the default contextual RAG model.
|
||||
|
||||
Should be called whenever the PRESENT search settings change
|
||||
(e.g. inline update or FUTURE → PRESENT swap).
|
||||
"""
|
||||
if (
|
||||
not enable_contextual_rag
|
||||
or not contextual_rag_llm_name
|
||||
or not contextual_rag_llm_provider
|
||||
):
|
||||
update_no_default_contextual_rag_provider(db_session=db_session)
|
||||
return
|
||||
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=contextual_rag_llm_provider, db_session=db_session
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider '{contextual_rag_llm_provider}' not found")
|
||||
|
||||
model_config = next(
|
||||
(
|
||||
mc
|
||||
for mc in provider.model_configurations
|
||||
if mc.name == contextual_rag_llm_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not model_config:
|
||||
raise ValueError(
|
||||
f"Model '{contextual_rag_llm_name}' not found for provider '{contextual_rag_llm_provider}'"
|
||||
)
|
||||
|
||||
add_model_to_flow(
|
||||
db_session=db_session,
|
||||
model_configuration_id=model_config.id,
|
||||
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
)
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=contextual_rag_llm_name,
|
||||
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers that are in Auto mode."""
|
||||
query = (
|
||||
@@ -760,9 +833,18 @@ def create_new_flow_mapping__no_commit(
|
||||
)
|
||||
|
||||
flow = result.scalar()
|
||||
if not flow:
|
||||
# Row already exists — fetch it
|
||||
flow = db_session.scalar(
|
||||
select(LLMModelFlow).where(
|
||||
LLMModelFlow.model_configuration_id == model_configuration_id,
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
)
|
||||
)
|
||||
if not flow:
|
||||
raise ValueError(
|
||||
f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
|
||||
f"Failed to create or find flow mapping for "
|
||||
f"model_configuration_id={model_configuration_id} and flow_type={flow_type}"
|
||||
)
|
||||
|
||||
return flow
|
||||
@@ -900,3 +982,18 @@ def _update_default_model(
|
||||
model_config.is_visible = True
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_model_to_flow(
|
||||
db_session: Session,
|
||||
model_configuration_id: int,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
# Function does nothing on conflict
|
||||
create_new_flow_mapping__no_commit(
|
||||
db_session=db_session,
|
||||
model_configuration_id=model_configuration_id,
|
||||
flow_type=flow_type,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -7,10 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
|
||||
MAX_MEMORIES_PER_USER = 10
|
||||
|
||||
@@ -36,6 +32,15 @@ class UserMemoryContext(BaseModel):
|
||||
user_preferences: str | None = None
|
||||
memories: tuple[str, ...] = ()
|
||||
|
||||
def without_memories(self) -> "UserMemoryContext":
|
||||
"""Return a copy with memories cleared but user info/preferences intact."""
|
||||
return UserMemoryContext(
|
||||
user_id=self.user_id,
|
||||
user_info=self.user_info,
|
||||
user_preferences=self.user_preferences,
|
||||
memories=(),
|
||||
)
|
||||
|
||||
def as_formatted_list(self) -> list[str]:
|
||||
"""Returns combined list of user info, preferences, and memories."""
|
||||
result = []
|
||||
@@ -50,45 +55,6 @@ class UserMemoryContext(BaseModel):
|
||||
result.extend(self.memories)
|
||||
return result
|
||||
|
||||
def as_formatted_prompt(self) -> str:
|
||||
"""Returns structured prompt sections for the system prompt."""
|
||||
has_basic_info = (
|
||||
self.user_info.name or self.user_info.email or self.user_info.role
|
||||
)
|
||||
if not has_basic_info and not self.user_preferences and not self.memories:
|
||||
return ""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
|
||||
if self.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=self.user_info.name or "",
|
||||
user_email=self.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
|
||||
)
|
||||
|
||||
if self.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
return "".join(sections)
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
user_info = UserInfo(
|
||||
|
||||
@@ -4877,3 +4877,90 @@ class BuildMessage(Base):
|
||||
"ix_build_message_session_turn", "session_id", "turn_index", "created_at"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
SCIM 2.0 Provisioning Models (Enterprise Edition only)
|
||||
Used for automated user/group provisioning from identity providers (Okta, Azure AD).
|
||||
"""
|
||||
|
||||
|
||||
class ScimToken(Base):
|
||||
"""Bearer tokens for IdP SCIM authentication."""
|
||||
|
||||
__tablename__ = "scim_token"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
hashed_token: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False
|
||||
) # SHA256 = 64 hex chars
|
||||
token_display: Mapped[str] = mapped_column(
|
||||
String, nullable=False
|
||||
) # Last 4 chars for UI identification
|
||||
|
||||
created_by_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
last_used_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
created_by: Mapped[User] = relationship("User", foreign_keys=[created_by_id])
|
||||
|
||||
|
||||
class ScimUserMapping(Base):
|
||||
"""Maps SCIM externalId from the IdP to an Onyx User."""
|
||||
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
|
||||
class ScimGroupMapping(Base):
|
||||
"""Maps SCIM externalId from the IdP to an Onyx UserGroup."""
|
||||
|
||||
__tablename__ = "scim_group_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
user_group: Mapped[UserGroup] = relationship(
|
||||
"UserGroup", foreign_keys=[user_group_id]
|
||||
)
|
||||
|
||||
@@ -15,6 +15,8 @@ from onyx.db.index_attempt import (
|
||||
count_unique_active_cc_pairs_with_successful_index_attempts,
|
||||
)
|
||||
from onyx.db.index_attempt import count_unique_cc_pairs_with_successful_index_attempts
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -80,6 +82,24 @@ def _perform_index_swap(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Update the default contextual model to match the newly promoted settings
|
||||
try:
|
||||
update_default_contextual_model(
|
||||
db_session=db_session,
|
||||
enable_contextual_rag=new_search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_provider=new_search_settings.contextual_rag_llm_provider,
|
||||
contextual_rag_llm_name=new_search_settings.contextual_rag_llm_name,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Model not found, defaulting to no contextual model: {e}")
|
||||
update_no_default_contextual_rag_provider(
|
||||
db_session=db_session,
|
||||
)
|
||||
new_search_settings.enable_contextual_rag = False
|
||||
new_search_settings.contextual_rag_llm_provider = None
|
||||
new_search_settings.contextual_rag_llm_name = None
|
||||
db_session.commit()
|
||||
|
||||
# This flow is for checking and possibly creating an index so we get all
|
||||
# indices.
|
||||
document_indices = get_all_document_indices(new_search_settings, None, None)
|
||||
|
||||
@@ -4,23 +4,21 @@ from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm.session import SessionTransaction
|
||||
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import AnswerStream
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import remove_answer_citations
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.evals.models import ChatFullEvalResult
|
||||
from onyx.evals.models import EvalationAck
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
from onyx.evals.models import EvalMessage
|
||||
@@ -33,18 +31,7 @@ from onyx.evals.provider import get_provider
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -87,193 +74,29 @@ def isolated_ephemeral_session_factory(
|
||||
conn.close()
|
||||
|
||||
|
||||
class GatherStreamResult(BaseModel):
|
||||
"""Result of gathering a stream with tool call information."""
|
||||
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
tools_called: list[str]
|
||||
tool_call_details: list[dict[str, Any]]
|
||||
message_id: int
|
||||
error_msg: str | None = None
|
||||
citations: list[CitationInfo] = []
|
||||
timings: EvalTimings | None = None
|
||||
|
||||
|
||||
def gather_stream_with_tools(packets: AnswerStream) -> GatherStreamResult:
|
||||
"""
|
||||
Gather streaming packets and extract both answer content and tool call information.
|
||||
|
||||
Returns a GatherStreamResult containing the answer and all tools that were called.
|
||||
"""
|
||||
stream_start_time = time.time()
|
||||
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] = []
|
||||
error_msg: str | None = None
|
||||
message_id: int | None = None
|
||||
tools_called: list[str] = []
|
||||
tool_call_details: list[dict[str, Any]] = []
|
||||
|
||||
# Timing tracking
|
||||
first_token_time: float | None = None
|
||||
tool_start_times: dict[str, float] = {} # tool_name -> start time
|
||||
tool_execution_ms: dict[str, float] = {} # tool_name -> duration in ms
|
||||
current_tool: str | None = None
|
||||
|
||||
def _finalize_tool_timing(tool_name: str) -> None:
|
||||
"""Record the duration for a tool that just finished."""
|
||||
if tool_name in tool_start_times:
|
||||
duration_ms = (time.time() - tool_start_times[tool_name]) * 1000
|
||||
tool_execution_ms[tool_name] = duration_ms
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, Packet):
|
||||
obj = packet.obj
|
||||
|
||||
# Handle answer content
|
||||
if isinstance(obj, AgentResponseStart):
|
||||
# When answer starts, finalize any in-progress tool
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
current_tool = None
|
||||
elif isinstance(obj, AgentResponseDelta):
|
||||
if answer is None:
|
||||
answer = ""
|
||||
first_token_time = time.time()
|
||||
if obj.content:
|
||||
answer += obj.content
|
||||
elif isinstance(obj, CitationInfo):
|
||||
citations.append(obj)
|
||||
|
||||
# Track tool calls with timing
|
||||
elif isinstance(obj, SearchToolStart):
|
||||
# Finalize any previous tool
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
tool_name = "WebSearchTool" if obj.is_internet_search else "SearchTool"
|
||||
current_tool = tool_name
|
||||
tool_start_times[tool_name] = time.time()
|
||||
tools_called.append(tool_name)
|
||||
tool_call_details.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_type": "search",
|
||||
"is_internet_search": obj.is_internet_search,
|
||||
}
|
||||
)
|
||||
elif isinstance(obj, ImageGenerationToolStart):
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
tool_name = "ImageGenerationTool"
|
||||
current_tool = tool_name
|
||||
tool_start_times[tool_name] = time.time()
|
||||
tools_called.append(tool_name)
|
||||
tool_call_details.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_type": "image_generation",
|
||||
}
|
||||
)
|
||||
elif isinstance(obj, PythonToolStart):
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
tool_name = "PythonTool"
|
||||
current_tool = tool_name
|
||||
tool_start_times[tool_name] = time.time()
|
||||
tools_called.append(tool_name)
|
||||
tool_call_details.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_type": "python",
|
||||
"code": obj.code,
|
||||
}
|
||||
)
|
||||
elif isinstance(obj, OpenUrlStart):
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
tool_name = "OpenURLTool"
|
||||
current_tool = tool_name
|
||||
tool_start_times[tool_name] = time.time()
|
||||
tools_called.append(tool_name)
|
||||
tool_call_details.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_type": "open_url",
|
||||
}
|
||||
)
|
||||
elif isinstance(obj, CustomToolStart):
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
tool_name = obj.tool_name
|
||||
current_tool = tool_name
|
||||
tool_start_times[tool_name] = time.time()
|
||||
tools_called.append(tool_name)
|
||||
tool_call_details.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_type": "custom",
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamingError):
|
||||
logger.warning(f"Streaming error during eval: {packet.error}")
|
||||
error_msg = packet.error
|
||||
elif isinstance(packet, MessageResponseIDInfo):
|
||||
message_id = packet.reserved_assistant_message_id
|
||||
|
||||
# Finalize any remaining tool timing
|
||||
if current_tool:
|
||||
_finalize_tool_timing(current_tool)
|
||||
|
||||
def _chat_full_response_to_eval_result(
|
||||
full: ChatFullResponse,
|
||||
stream_start_time: float,
|
||||
) -> ChatFullEvalResult:
|
||||
"""Map ChatFullResponse from gather_stream_full to eval result components."""
|
||||
tools_called = [tc.tool_name for tc in full.tool_calls]
|
||||
tool_call_details: list[dict[str, Any]] = [
|
||||
{"tool_name": tc.tool_name, "tool_arguments": tc.tool_arguments}
|
||||
for tc in full.tool_calls
|
||||
]
|
||||
stream_end_time = time.time()
|
||||
|
||||
if message_id is None:
|
||||
# If we got a streaming error, include it in the exception
|
||||
if error_msg:
|
||||
raise ValueError(f"Message ID is required. Stream error: {error_msg}")
|
||||
raise ValueError(
|
||||
f"Message ID is required. No MessageResponseIDInfo received. "
|
||||
f"Tools called: {tools_called}"
|
||||
)
|
||||
|
||||
# Allow empty answers for tool-only turns (e.g., in multi-turn evals)
|
||||
# Some turns may only execute tools without generating a text response
|
||||
if answer is None:
|
||||
logger.warning(
|
||||
"No answer content generated. Tools called: %s. "
|
||||
"This may be expected for tool-only turns.",
|
||||
tools_called,
|
||||
)
|
||||
answer = ""
|
||||
|
||||
# Calculate timings
|
||||
total_ms = (stream_end_time - stream_start_time) * 1000
|
||||
first_token_ms = (
|
||||
(first_token_time - stream_start_time) * 1000 if first_token_time else None
|
||||
)
|
||||
stream_processing_ms = (stream_end_time - stream_start_time) * 1000
|
||||
|
||||
timings = EvalTimings(
|
||||
total_ms=total_ms,
|
||||
llm_first_token_ms=first_token_ms,
|
||||
tool_execution_ms=tool_execution_ms,
|
||||
stream_processing_ms=stream_processing_ms,
|
||||
llm_first_token_ms=None,
|
||||
tool_execution_ms={},
|
||||
stream_processing_ms=total_ms,
|
||||
)
|
||||
|
||||
return GatherStreamResult(
|
||||
answer=answer,
|
||||
answer_citationless=remove_answer_citations(answer),
|
||||
return ChatFullEvalResult(
|
||||
answer=full.answer,
|
||||
tools_called=tools_called,
|
||||
tool_call_details=tool_call_details,
|
||||
message_id=message_id,
|
||||
error_msg=error_msg,
|
||||
citations=citations,
|
||||
citations=full.citation_info,
|
||||
timings=timings,
|
||||
)
|
||||
|
||||
@@ -413,14 +236,17 @@ def _get_answer_with_tools(
|
||||
),
|
||||
)
|
||||
|
||||
stream_start_time = time.time()
|
||||
state_container = ChatStateContainer()
|
||||
packets = handle_stream_message_objects(
|
||||
new_msg_req=request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
external_state_container=state_container,
|
||||
)
|
||||
full = gather_stream_full(packets, state_container)
|
||||
|
||||
# Gather stream with tool call tracking
|
||||
result = gather_stream_with_tools(packets)
|
||||
result = _chat_full_response_to_eval_result(full, stream_start_time)
|
||||
|
||||
# Evaluate tool assertions
|
||||
assertion_passed, assertion_details = evaluate_tool_assertions(
|
||||
@@ -551,30 +377,30 @@ def _get_multi_turn_answer_with_tools(
|
||||
),
|
||||
)
|
||||
|
||||
# Create request for this turn
|
||||
# Create request for this turn using SendMessageRequest (same API as handle_stream_message_objects)
|
||||
# Use AUTO_PLACE_AFTER_LATEST_MESSAGE to chain messages
|
||||
request = CreateChatMessageRequest(
|
||||
forced_tool_id = forced_tool_ids[0] if forced_tool_ids else None
|
||||
request = SendMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=AUTO_PLACE_AFTER_LATEST_MESSAGE,
|
||||
message=msg.message,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=None,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
llm_override=llm_override,
|
||||
persona_override_config=full_configuration.persona_override_config,
|
||||
skip_gen_ai_answer_generation=False,
|
||||
allowed_tool_ids=full_configuration.allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids or None,
|
||||
forced_tool_id=forced_tool_id,
|
||||
)
|
||||
|
||||
# Stream and gather results for this turn
|
||||
packets = stream_chat_message_objects(
|
||||
# Stream and gather results for this turn via handle_stream_message_objects + gather_stream_full
|
||||
stream_start_time = time.time()
|
||||
state_container = ChatStateContainer()
|
||||
packets = handle_stream_message_objects(
|
||||
new_msg_req=request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
external_state_container=state_container,
|
||||
)
|
||||
full = gather_stream_full(packets, state_container)
|
||||
|
||||
result = gather_stream_with_tools(packets)
|
||||
result = _chat_full_response_to_eval_result(full, stream_start_time)
|
||||
|
||||
# Evaluate tool assertions for this turn
|
||||
assertion_passed, assertion_details = evaluate_tool_assertions(
|
||||
|
||||
@@ -7,9 +7,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import PromptOverrideConfig
|
||||
from onyx.chat.models import ToolConfig
|
||||
from onyx.db.tools import get_builtin_tool
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -34,6 +31,16 @@ class EvalTimings(BaseModel):
|
||||
stream_processing_ms: float | None = None # Time to process the stream
|
||||
|
||||
|
||||
class ChatFullEvalResult(BaseModel):
|
||||
"""Raw eval components from ChatFullResponse (before tool assertions)."""
|
||||
|
||||
answer: str
|
||||
tools_called: list[str]
|
||||
tool_call_details: list[dict[str, Any]]
|
||||
citations: list[CitationInfo]
|
||||
timings: EvalTimings
|
||||
|
||||
|
||||
class EvalToolResult(BaseModel):
|
||||
"""Result of a single eval with tool call information."""
|
||||
|
||||
@@ -72,8 +79,6 @@ class MultiTurnEvalResult(BaseModel):
|
||||
|
||||
|
||||
class EvalConfiguration(BaseModel):
|
||||
builtin_tool_types: list[str] = Field(default_factory=list)
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
llm: LLMOverride = Field(default_factory=LLMOverride)
|
||||
search_permissions_email: str
|
||||
allowed_tool_ids: list[int]
|
||||
@@ -81,7 +86,6 @@ class EvalConfiguration(BaseModel):
|
||||
|
||||
class EvalConfigurationOptions(BaseModel):
|
||||
builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
llm: LLMOverride = LLMOverride(
|
||||
model_provider=None,
|
||||
model_version="gpt-4o",
|
||||
@@ -96,26 +100,7 @@ class EvalConfigurationOptions(BaseModel):
|
||||
experiment_name: str | None = None
|
||||
|
||||
def get_configuration(self, db_session: Session) -> EvalConfiguration:
|
||||
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
|
||||
name="Eval",
|
||||
description="A persona for evaluation",
|
||||
tools=[
|
||||
ToolConfig(id=get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id)
|
||||
for tool in self.builtin_tool_types
|
||||
],
|
||||
prompts=[
|
||||
PromptOverrideConfig(
|
||||
name="Default",
|
||||
description="Default prompt for evaluation",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return EvalConfiguration(
|
||||
persona_override_config=persona_override_config,
|
||||
llm=self.llm,
|
||||
search_permissions_email=self.search_permissions_email,
|
||||
allowed_tool_ids=[
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
@@ -77,7 +76,7 @@ def _build_model_kwargs(
|
||||
|
||||
|
||||
def get_llm_for_persona(
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
persona: Persona | None,
|
||||
user: User,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
@@ -102,20 +101,16 @@ def get_llm_for_persona(
|
||||
if not provider_model:
|
||||
raise ValueError("No LLM provider found")
|
||||
|
||||
# Only check access control for database Persona entities, not PersonaOverrideConfig
|
||||
# PersonaOverrideConfig is used for temporary overrides and doesn't have access restrictions
|
||||
persona_model = persona if isinstance(persona, Persona) else None
|
||||
|
||||
# Fetch user group IDs for access control check
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
|
||||
if not can_user_access_llm_provider(
|
||||
provider_model, user_group_ids, persona_model, user.role == UserRole.ADMIN
|
||||
provider_model, user_group_ids, persona, user.role == UserRole.ADMIN
|
||||
):
|
||||
logger.warning(
|
||||
"User %s with persona %s cannot access provider %s. Falling back to default provider.",
|
||||
user.id,
|
||||
getattr(persona_model, "id", None),
|
||||
persona.id,
|
||||
provider_model.name,
|
||||
)
|
||||
return get_default_llm(
|
||||
|
||||
@@ -92,7 +92,7 @@ class CacheableMessage(BaseModel):
|
||||
|
||||
class SystemMessage(CacheableMessage):
|
||||
role: Literal["system"] = "system"
|
||||
content: str | list[ContentPart]
|
||||
content: str
|
||||
|
||||
|
||||
class UserMessage(CacheableMessage):
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -49,6 +53,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_env_lock = threading.Lock()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import HTTPHandler
|
||||
@@ -378,23 +384,30 @@ class LitellmLLM(LLM):
|
||||
if "api_key" not in passthrough_kwargs:
|
||||
passthrough_kwargs["api_key"] = self._api_key or None
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
# We only need to set environment variables if custom config is set
|
||||
env_ctx = (
|
||||
temporary_env_and_lock(self._custom_config)
|
||||
if self._custom_config
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# for break pointing
|
||||
@@ -475,13 +488,21 @@ class LitellmLLM(LLM):
|
||||
client = HTTPHandler(timeout=timeout_override or self._timeout)
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
# When custom_config is set, env vars are temporarily injected
|
||||
# under a global lock. Using stream=True here means the lock is
|
||||
# only held during connection setup (not the full inference).
|
||||
# The chunks are then collected outside the lock and reassembled
|
||||
# into a single ModelResponse via stream_chunk_builder.
|
||||
from litellm import stream_chunk_builder
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
|
||||
stream_response = cast(
|
||||
LiteLLMCustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
@@ -491,6 +512,11 @@ class LitellmLLM(LLM):
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
chunks = list(stream_response)
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
stream_chunk_builder(chunks),
|
||||
)
|
||||
|
||||
model_response = from_litellm_model_response(response)
|
||||
|
||||
@@ -581,3 +607,29 @@ class LitellmLLM(LLM):
|
||||
finally:
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
|
||||
"""
|
||||
Temporarily sets the environment variables to the given values.
|
||||
Code path is locked while the environment variables are set.
|
||||
Then cleans up the environment and frees the lock.
|
||||
"""
|
||||
with _env_lock:
|
||||
logger.debug("Acquired lock in temporary_env_and_lock")
|
||||
# Store original values (None if key didn't exist)
|
||||
original_values: dict[str, str | None] = {
|
||||
key: os.environ.get(key) for key in env_variables
|
||||
}
|
||||
try:
|
||||
os.environ.update(env_variables)
|
||||
yield
|
||||
finally:
|
||||
for key, original_value in original_values.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None) # Remove if it didn't exist before
|
||||
else:
|
||||
os.environ[key] = original_value # Restore original value
|
||||
|
||||
logger.debug("Released lock in temporary_env_and_lock")
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import uvicorn
|
||||
|
||||
from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -15,13 +16,13 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
logger.info(f"Starting MCP server on 0.0.0.0:{MCP_SERVER_PORT}")
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
uvicorn.run(
|
||||
mcp_app,
|
||||
host="0.0.0.0",
|
||||
host=MCP_SERVER_HOST,
|
||||
port=MCP_SERVER_PORT,
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
|
||||
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
|
||||
from onyx.prompts.constants import REMINDER_TAG_NO_HEADER
|
||||
|
||||
|
||||
DATETIME_REPLACEMENT_PAT = "{{CURRENT_DATETIME}}"
|
||||
CITATION_GUIDANCE_REPLACEMENT_PAT = "{{CITATION_GUIDANCE}}"
|
||||
ALT_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
|
||||
ALT_CITATION_GUIDANCE_REPLACEMENT_PAT = "[[CITATION_GUIDANCE]]"
|
||||
REMINDER_TAG_REPLACEMENT_PAT = "{{REMINDER_TAG_DESCRIPTION}}"
|
||||
|
||||
|
||||
# Note this uses a string pattern replacement so the user can also include it in their custom prompts. Keeps the replacement logic simple
|
||||
@@ -27,7 +25,7 @@ For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
|
||||
{REMINDER_TAG_DESCRIPTION}
|
||||
{REMINDER_TAG_REPLACEMENT_PAT}
|
||||
""".lstrip()
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
CODE_BLOCK_PAT = "```\n{}\n```"
|
||||
TRIPLE_BACKTICK = "```"
|
||||
SYSTEM_REMINDER_TAG_OPEN = "<system-reminder>"
|
||||
@@ -5,13 +6,12 @@ SYSTEM_REMINDER_TAG_CLOSE = "</system-reminder>"
|
||||
|
||||
# Tags format inspired by Anthropic and OpenCode
|
||||
REMINDER_TAG_NO_HEADER = f"""
|
||||
User messages may include {SYSTEM_REMINDER_TAG_OPEN} and {SYSTEM_REMINDER_TAG_CLOSE} tags.
|
||||
These {SYSTEM_REMINDER_TAG_OPEN} tags contain useful information and reminders. \
|
||||
They are automatically added by the system and are not actual user inputs.
|
||||
Behave in accordance to these instructions if relevant, and continue normally if they are not.
|
||||
User messages may include {SYSTEM_REMINDER_TAG_OPEN} and {SYSTEM_REMINDER_TAG_CLOSE} tags. These {SYSTEM_REMINDER_TAG_OPEN} tags contain useful information and reminders. \
|
||||
They are automatically added by the system and are not actual user inputs. Behave in accordance to these instructions if relevant, and continue normally if they are not.
|
||||
""".strip()
|
||||
|
||||
REMINDER_TAG_DESCRIPTION = f"""
|
||||
# System Reminders
|
||||
{REMINDER_TAG_NO_HEADER}
|
||||
""".strip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
@@ -5,14 +5,14 @@ from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.prompts.chat_prompts import ADDITIONAL_INFO
|
||||
from onyx.prompts.chat_prompts import ALT_CITATION_GUIDANCE_REPLACEMENT_PAT
|
||||
from onyx.prompts.chat_prompts import ALT_DATETIME_REPLACEMENT_PAT
|
||||
from onyx.prompts.chat_prompts import CITATION_GUIDANCE_REPLACEMENT_PAT
|
||||
from onyx.prompts.chat_prompts import COMPANY_DESCRIPTION_BLOCK
|
||||
from onyx.prompts.chat_prompts import COMPANY_NAME_BLOCK
|
||||
from onyx.prompts.chat_prompts import DATETIME_REPLACEMENT_PAT
|
||||
from onyx.prompts.chat_prompts import REMINDER_TAG_REPLACEMENT_PAT
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.constants import CODE_BLOCK_PAT
|
||||
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -54,11 +54,8 @@ def replace_current_datetime_tag(
|
||||
include_day_of_week=include_day_of_week,
|
||||
)
|
||||
|
||||
# Check and replace both patterns: {{CURRENT_DATETIME}} and [[CURRENT_DATETIME]]
|
||||
if DATETIME_REPLACEMENT_PAT in prompt_str:
|
||||
prompt_str = prompt_str.replace(DATETIME_REPLACEMENT_PAT, datetime_str)
|
||||
if ALT_DATETIME_REPLACEMENT_PAT in prompt_str:
|
||||
prompt_str = prompt_str.replace(ALT_DATETIME_REPLACEMENT_PAT, datetime_str)
|
||||
|
||||
return prompt_str
|
||||
|
||||
@@ -70,7 +67,7 @@ def replace_citation_guidance_tag(
|
||||
include_all_guidance: bool = False,
|
||||
) -> tuple[str, bool]:
|
||||
"""
|
||||
Replace {{CITATION_GUIDANCE}} or [[CITATION_GUIDANCE]] placeholder with citation guidance if needed.
|
||||
Replace {{CITATION_GUIDANCE}} placeholder with citation guidance if needed.
|
||||
|
||||
Returns:
|
||||
tuple[str, bool]: (prompt_with_replacement, should_append_fallback)
|
||||
@@ -78,10 +75,7 @@ def replace_citation_guidance_tag(
|
||||
- should_append_fallback: True if citation guidance should be appended
|
||||
(placeholder is not present and citations are needed)
|
||||
"""
|
||||
# Check for both patterns: {{CITATION_GUIDANCE}} and [[CITATION_GUIDANCE]]
|
||||
has_primary_pattern = CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
|
||||
has_alt_pattern = ALT_CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
|
||||
placeholder_was_present = has_primary_pattern or has_alt_pattern
|
||||
placeholder_was_present = CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
|
||||
|
||||
if not placeholder_was_present:
|
||||
# Placeholder not present - caller should append if citations are needed
|
||||
@@ -96,30 +90,32 @@ def replace_citation_guidance_tag(
|
||||
else ""
|
||||
)
|
||||
|
||||
# Replace both patterns if present
|
||||
if has_primary_pattern:
|
||||
prompt_str = prompt_str.replace(
|
||||
CITATION_GUIDANCE_REPLACEMENT_PAT,
|
||||
citation_guidance,
|
||||
)
|
||||
if has_alt_pattern:
|
||||
prompt_str = prompt_str.replace(
|
||||
ALT_CITATION_GUIDANCE_REPLACEMENT_PAT,
|
||||
citation_guidance,
|
||||
)
|
||||
prompt_str = prompt_str.replace(
|
||||
CITATION_GUIDANCE_REPLACEMENT_PAT,
|
||||
citation_guidance,
|
||||
)
|
||||
|
||||
return prompt_str, False
|
||||
|
||||
|
||||
def replace_reminder_tag(prompt_str: str) -> str:
|
||||
"""Replace {{REMINDER_TAG_DESCRIPTION}} with the reminder tag content."""
|
||||
if REMINDER_TAG_REPLACEMENT_PAT in prompt_str:
|
||||
prompt_str = prompt_str.replace(
|
||||
REMINDER_TAG_REPLACEMENT_PAT, REMINDER_TAG_DESCRIPTION
|
||||
)
|
||||
|
||||
return prompt_str
|
||||
|
||||
|
||||
def handle_onyx_date_awareness(
|
||||
prompt_str: str,
|
||||
# We always replace the pattern {{CURRENT_DATETIME}} or [[CURRENT_DATETIME]] if it shows up
|
||||
# We always replace the pattern {{CURRENT_DATETIME}} if it shows up
|
||||
# but if it doesn't show up and the prompt is datetime aware, add it to the prompt at the end.
|
||||
datetime_aware: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
If there is a {{CURRENT_DATETIME}} or [[CURRENT_DATETIME]] tag, replace it with the current
|
||||
date and time no matter what.
|
||||
If there is a {{CURRENT_DATETIME}} tag, replace it with the current date and time no matter what.
|
||||
If the prompt is datetime aware, and there are no datetime tags, add it to the prompt.
|
||||
Do nothing otherwise.
|
||||
This can later be expanded to support other tags.
|
||||
|
||||
@@ -85,7 +85,7 @@ def send_message(
|
||||
Enforces rate limiting before executing the agent (via dependency).
|
||||
Returns a Server-Sent Events (SSE) stream with the agent's response.
|
||||
|
||||
Follows the same pattern as /chat/send-message for consistency.
|
||||
Follows the same pattern as /chat/send-chat-message for consistency.
|
||||
"""
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
|
||||
@@ -4,8 +4,9 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
|
||||
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
|
||||
and uses the native ACP subprocess protocol.
|
||||
|
||||
This module includes comprehensive logging for debugging ACP communication.
|
||||
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
|
||||
When multiple API server replicas share the same sandbox pod, this client
|
||||
uses ACP session resumption (session/list + session/resume) to maintain
|
||||
conversation context across replicas.
|
||||
|
||||
Usage:
|
||||
client = ACPExecClient(
|
||||
@@ -100,7 +101,7 @@ class ACPClientState:
|
||||
"""Internal state for the ACP client."""
|
||||
|
||||
initialized: bool = False
|
||||
current_session: ACPSession | None = None
|
||||
sessions: dict[str, ACPSession] = field(default_factory=dict)
|
||||
next_request_id: int = 0
|
||||
agent_capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
agent_info: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -144,6 +145,7 @@ class ACPExecClient:
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
self._stop_reader = threading.Event()
|
||||
self._k8s_client: client.CoreV1Api | None = None
|
||||
self._prompt_count: int = 0 # Track how many prompts sent on this client
|
||||
|
||||
def _get_k8s_client(self) -> client.CoreV1Api:
|
||||
"""Get or create kubernetes client."""
|
||||
@@ -155,16 +157,16 @@ class ACPExecClient:
|
||||
self._k8s_client = client.CoreV1Api()
|
||||
return self._k8s_client
|
||||
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> str:
|
||||
"""Start the agent process via exec and initialize a session.
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> None:
|
||||
"""Start the agent process via exec and initialize the ACP connection.
|
||||
|
||||
Only performs the ACP `initialize` handshake. Sessions are created
|
||||
separately via `create_session()` or `resume_session()`.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the agent
|
||||
cwd: Working directory for the `opencode acp` process
|
||||
timeout: Timeout for initialization
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If startup fails
|
||||
"""
|
||||
@@ -176,6 +178,8 @@ class ACPExecClient:
|
||||
# Start opencode acp via exec
|
||||
exec_command = ["opencode", "acp", "--cwd", cwd]
|
||||
|
||||
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
|
||||
|
||||
try:
|
||||
self._ws_client = k8s_stream(
|
||||
k8s.connect_get_namespaced_pod_exec,
|
||||
@@ -201,15 +205,13 @@ class ACPExecClient:
|
||||
# Give process a moment to start
|
||||
time.sleep(0.5)
|
||||
|
||||
# Initialize ACP connection
|
||||
# Initialize ACP connection (no session creation)
|
||||
self._initialize(timeout=timeout)
|
||||
|
||||
# Create session
|
||||
session_id = self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
return session_id
|
||||
logger.info(f"[ACP] Client started: pod={self._pod_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
|
||||
self.stop()
|
||||
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
|
||||
|
||||
@@ -217,63 +219,153 @@ class ACPExecClient:
|
||||
"""Background thread to read responses from the exec stream."""
|
||||
buffer = ""
|
||||
packet_logger = get_packet_logger()
|
||||
messages_read = 0
|
||||
# Track how many consecutive read cycles the buffer has had
|
||||
# unterminated data (no trailing newline) with no new data arriving.
|
||||
buffer_stale_cycles = 0
|
||||
# Track empty read cycles for periodic buffer state logging
|
||||
empty_read_cycles = 0
|
||||
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
logger.debug(f"[ACP] Reader thread started for pod={self._pod_name}")
|
||||
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
try:
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stderr - log any agent errors
|
||||
stderr_data = self._ws_client.read_stderr(timeout=0.01)
|
||||
if stderr_data:
|
||||
logger.warning(
|
||||
f"[ACP] stderr pod={self._pod_name}: "
|
||||
f"{stderr_data.strip()[:500]}"
|
||||
)
|
||||
|
||||
# Read stdout
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
buffer_stale_cycles = 0
|
||||
empty_read_cycles = 0
|
||||
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
messages_read += 1
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Invalid JSON from agent: "
|
||||
f"{line[:100]}"
|
||||
)
|
||||
else:
|
||||
empty_read_cycles += 1
|
||||
|
||||
# No new data arrived this cycle. If the buffer
|
||||
# has unterminated content, track how long it's
|
||||
# been sitting there. After a few cycles (~0.5s)
|
||||
# try to parse it — the agent may have sent the
|
||||
# last message without a trailing newline.
|
||||
if buffer.strip():
|
||||
buffer_stale_cycles += 1
|
||||
if buffer_stale_cycles == 1:
|
||||
logger.info(
|
||||
f"[ACP] Buffer has unterminated data: "
|
||||
f"{len(buffer)} bytes, "
|
||||
f"preview={buffer.strip()[:200]}"
|
||||
)
|
||||
if buffer_stale_cycles >= 3:
|
||||
logger.info(
|
||||
f"[ACP] Attempting stale buffer parse: "
|
||||
f"{len(buffer)} bytes, "
|
||||
f"cycles={buffer_stale_cycles}"
|
||||
)
|
||||
try:
|
||||
message = json.loads(buffer.strip())
|
||||
messages_read += 1
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN",
|
||||
message,
|
||||
context="k8s-unterminated",
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
buffer = ""
|
||||
buffer_stale_cycles = 0
|
||||
logger.info(
|
||||
"[ACP] Stale buffer parsed successfully"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON yet, keep waiting
|
||||
logger.debug(
|
||||
f"[ACP] Stale buffer not valid JSON: "
|
||||
f"{buffer.strip()[:100]}"
|
||||
)
|
||||
|
||||
# Periodic log: every ~5s (50 cycles at 0.1s each)
|
||||
# when we're idle with an empty buffer — helps
|
||||
# confirm the reader is alive and waiting.
|
||||
if empty_read_cycles % 50 == 0:
|
||||
logger.info(
|
||||
f"[ACP] Reader idle: "
|
||||
f"empty_cycles={empty_read_cycles} "
|
||||
f"buffer={len(buffer)} bytes "
|
||||
f"messages_read={messages_read} "
|
||||
f"pod={self._pod_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"[ACP] WebSocket closed: pod={self._pod_name}, "
|
||||
f"messages_read={messages_read}"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
|
||||
break
|
||||
finally:
|
||||
# Flush any remaining data in buffer
|
||||
remaining = buffer.strip()
|
||||
if remaining:
|
||||
logger.info(
|
||||
f"[ACP] Flushing buffer on exit: {len(remaining)} bytes, "
|
||||
f"preview={remaining[:200]}"
|
||||
)
|
||||
try:
|
||||
message = json.loads(remaining)
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s-flush"
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Buffer flush failed (not JSON): " f"{remaining[:200]}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Reader thread exiting: pod={self._pod_name}, "
|
||||
f"messages_read={messages_read}, "
|
||||
f"empty_read_cycles={empty_read_cycles}"
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the exec session and clean up."""
|
||||
session_ids = list(self._state.sessions.keys())
|
||||
logger.info(
|
||||
f"[ACP] Stopping client: pod={self._pod_name} "
|
||||
f"sessions={session_ids} prompts_sent={self._prompt_count}"
|
||||
)
|
||||
self._stop_reader.set()
|
||||
|
||||
if self._ws_client is not None:
|
||||
@@ -400,44 +492,215 @@ class ACPExecClient:
|
||||
if not session_id:
|
||||
raise RuntimeError("No session ID returned from session/new")
|
||||
|
||||
self._state.current_session = ACPSession(session_id=session_id, cwd=cwd)
|
||||
self._state.sessions[session_id] = ACPSession(session_id=session_id, cwd=cwd)
|
||||
logger.info(f"[ACP] Created session: acp_session={session_id} cwd={cwd}")
|
||||
|
||||
return session_id
|
||||
|
||||
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
|
||||
"""List available ACP sessions, filtered by working directory.
|
||||
|
||||
Returns:
|
||||
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
|
||||
Empty list if session/list is not supported or fails.
|
||||
"""
|
||||
try:
|
||||
request_id = self._send_request("session/list", {"cwd": cwd})
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
sessions = result.get("sessions", [])
|
||||
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
|
||||
return sessions
|
||||
except Exception as e:
|
||||
logger.info(f"[ACP] session/list unavailable: {e}")
|
||||
return []
|
||||
|
||||
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the resume request
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume fails
|
||||
"""
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
"cwd": cwd,
|
||||
"mcpServers": [],
|
||||
}
|
||||
|
||||
request_id = self._send_request("session/resume", params)
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
|
||||
# The response should contain the session ID
|
||||
resumed_id = result.get("sessionId", session_id)
|
||||
self._state.sessions[resumed_id] = ACPSession(session_id=resumed_id, cwd=cwd)
|
||||
|
||||
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
|
||||
return resumed_id
|
||||
|
||||
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
|
||||
"""Try to find and resume an existing session for this workspace.
|
||||
|
||||
When multiple API server replicas connect to the same sandbox pod,
|
||||
a previous replica may have already created an ACP session for this
|
||||
workspace. This method discovers and resumes that session so the
|
||||
agent retains conversation context.
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search for sessions
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The resumed session ID, or None if no session could be resumed
|
||||
"""
|
||||
# Check if the agent supports session/list + session/resume
|
||||
session_caps = self._state.agent_capabilities.get("sessionCapabilities", {})
|
||||
supports_list = session_caps.get("list") is not None
|
||||
supports_resume = session_caps.get("resume") is not None
|
||||
|
||||
if not supports_list or not supports_resume:
|
||||
logger.debug("[ACP] Agent does not support session resume")
|
||||
return None
|
||||
|
||||
# List sessions for this workspace directory
|
||||
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
|
||||
if not sessions:
|
||||
return None
|
||||
|
||||
# Pick the most recent session (first in list, assuming sorted)
|
||||
target = sessions[0]
|
||||
target_id = target.get("sessionId")
|
||||
if not target_id:
|
||||
logger.warning(
|
||||
"[ACP-LIFECYCLE] session/list returned session without sessionId"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Resuming existing session: acp_session={target_id} "
|
||||
f"(found {len(sessions)})"
|
||||
)
|
||||
|
||||
try:
|
||||
return self._resume_session(target_id, cwd, timeout)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ACP] session/resume failed for {target_id}: {e}, "
|
||||
f"falling back to session/new"
|
||||
)
|
||||
return None
|
||||
|
||||
def create_session(self, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Create a new ACP session on this connection.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the request
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
return self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
def resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session on this connection.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the request
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
return self._resume_session(session_id=session_id, cwd=cwd, timeout=timeout)
|
||||
|
||||
def get_or_create_session(self, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Get an existing session for this cwd, or create/resume one.
|
||||
|
||||
Tries in order:
|
||||
1. Return an already-tracked session for this cwd
|
||||
2. Resume an existing session from opencode's storage (multi-replica)
|
||||
3. Create a new session
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
|
||||
# Check if we already have a session for this cwd
|
||||
for sid, session in self._state.sessions.items():
|
||||
if session.cwd == cwd:
|
||||
logger.info(
|
||||
f"[ACP] Reusing existing session: " f"acp_session={sid} cwd={cwd}"
|
||||
)
|
||||
return sid
|
||||
|
||||
# Try to resume from opencode's persisted storage
|
||||
resumed_id = self._try_resume_existing_session(cwd, timeout)
|
||||
if resumed_id:
|
||||
return resumed_id
|
||||
|
||||
# Create a new session
|
||||
return self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
timeout: float = ACP_MESSAGE_TIMEOUT,
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message and stream response events.
|
||||
"""Send a message to a specific session and stream response events.
|
||||
|
||||
Args:
|
||||
message: The message content to send
|
||||
session_id: The ACP session ID to send the message to
|
||||
timeout: Maximum time to wait for complete response (defaults to ACP_MESSAGE_TIMEOUT env var)
|
||||
|
||||
Yields:
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
if self._state.current_session is None:
|
||||
raise RuntimeError("No active session. Call start() first.")
|
||||
|
||||
session_id = self._state.current_session.session_id
|
||||
if session_id not in self._state.sessions:
|
||||
raise RuntimeError(
|
||||
f"Unknown session {session_id}. "
|
||||
f"Known sessions: {list(self._state.sessions.keys())}"
|
||||
)
|
||||
packet_logger = get_packet_logger()
|
||||
self._prompt_count += 1
|
||||
prompt_num = self._prompt_count
|
||||
|
||||
# Log the start of message processing
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-START-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"pod": self._pod_name,
|
||||
"namespace": self._namespace,
|
||||
"message_preview": (
|
||||
message[:200] + "..." if len(message) > 200 else message
|
||||
),
|
||||
"timeout": timeout,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} start: "
|
||||
f"acp_session={session_id} pod={self._pod_name}"
|
||||
)
|
||||
|
||||
# Drain leftover messages from the queue (e.g., session_info_update
|
||||
# that arrived between prompts).
|
||||
drained_count = 0
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
self._response_queue.get_nowait()
|
||||
drained_count += 1
|
||||
except Empty:
|
||||
break
|
||||
if drained_count > 0:
|
||||
logger.debug(f"[ACP] Drained {drained_count} stale messages")
|
||||
|
||||
prompt_content = [{"type": "text", "text": message}]
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
@@ -446,44 +709,109 @@ class ACPExecClient:
|
||||
|
||||
request_id = self._send_request("session/prompt", params)
|
||||
start_time = time.time()
|
||||
last_event_time = time.time() # Track time since last event for keepalive
|
||||
last_event_time = time.time()
|
||||
events_yielded = 0
|
||||
messages_processed = 0
|
||||
keepalive_count = 0
|
||||
completion_reason = "unknown"
|
||||
|
||||
while True:
|
||||
remaining = timeout - (time.time() - start_time)
|
||||
if remaining <= 0:
|
||||
packet_logger.log_raw(
|
||||
"ACP-TIMEOUT-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"elapsed_ms": (time.time() - start_time) * 1000,
|
||||
},
|
||||
completion_reason = "timeout"
|
||||
logger.warning(
|
||||
f"[ACP] Prompt #{prompt_num} timeout: "
|
||||
f"acp_session={session_id} events={events_yielded}"
|
||||
)
|
||||
yield Error(code=-1, message="Timeout waiting for response")
|
||||
break
|
||||
|
||||
try:
|
||||
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
|
||||
last_event_time = time.time() # Reset keepalive timer on event
|
||||
last_event_time = time.time()
|
||||
messages_processed += 1
|
||||
|
||||
# Log every dequeued message for prompt #2+ to diagnose
|
||||
# why the response isn't being matched.
|
||||
if prompt_num >= 2:
|
||||
msg_id = message_data.get("id")
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} dequeued: "
|
||||
f"id={msg_id} type(id)={type(msg_id).__name__} "
|
||||
f"method={message_data.get('method')} "
|
||||
f"keys={list(message_data.keys())} "
|
||||
f"request_id={request_id}"
|
||||
)
|
||||
except Empty:
|
||||
# Check if we need to send an SSE keepalive
|
||||
# Check if reader thread is still alive
|
||||
if (
|
||||
self._reader_thread is not None
|
||||
and not self._reader_thread.is_alive()
|
||||
):
|
||||
completion_reason = "reader_thread_dead"
|
||||
# Drain any final messages the reader flushed before dying
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
final_msg = self._response_queue.get_nowait()
|
||||
if final_msg.get("id") == request_id:
|
||||
if "error" in final_msg:
|
||||
error_data = final_msg["error"]
|
||||
yield Error(
|
||||
code=error_data.get("code", -1),
|
||||
message=error_data.get(
|
||||
"message", "Unknown error"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = final_msg.get("result", {})
|
||||
try:
|
||||
yield PromptResponse.model_validate(result)
|
||||
except ValidationError:
|
||||
pass
|
||||
break
|
||||
except Empty:
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"[ACP] Reader thread dead: prompt #{prompt_num} "
|
||||
f"acp_session={session_id} events={events_yielded}"
|
||||
)
|
||||
break
|
||||
|
||||
# Send SSE keepalive if idle
|
||||
idle_time = time.time() - last_event_time
|
||||
if idle_time >= SSE_KEEPALIVE_INTERVAL:
|
||||
packet_logger.log_raw(
|
||||
"SSE-KEEPALIVE-YIELD",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"idle_seconds": idle_time,
|
||||
},
|
||||
)
|
||||
keepalive_count += 1
|
||||
if keepalive_count % 3 == 0:
|
||||
reader_alive = (
|
||||
self._reader_thread is not None
|
||||
and self._reader_thread.is_alive()
|
||||
)
|
||||
elapsed_s = time.time() - start_time
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} waiting: "
|
||||
f"keepalives={keepalive_count} "
|
||||
f"elapsed={elapsed_s:.0f}s "
|
||||
f"events={events_yielded} "
|
||||
f"reader_alive={reader_alive} "
|
||||
f"queue_size={self._response_queue.qsize()}"
|
||||
)
|
||||
yield SSEKeepalive()
|
||||
last_event_time = time.time() # Reset after yielding keepalive
|
||||
last_event_time = time.time()
|
||||
continue
|
||||
|
||||
# Check for response to our prompt request
|
||||
if message_data.get("id") == request_id:
|
||||
# Check for JSON-RPC response to our prompt request.
|
||||
msg_id = message_data.get("id")
|
||||
is_response = "method" not in message_data and (
|
||||
msg_id == request_id
|
||||
or (msg_id is not None and str(msg_id) == str(request_id))
|
||||
)
|
||||
if is_response:
|
||||
completion_reason = "jsonrpc_response"
|
||||
if "error" in message_data:
|
||||
error_data = message_data["error"]
|
||||
completion_reason = "jsonrpc_error"
|
||||
logger.warning(f"[ACP] Prompt #{prompt_num} error: {error_data}")
|
||||
packet_logger.log_jsonrpc_response(
|
||||
request_id, error=error_data, context="k8s"
|
||||
)
|
||||
@@ -498,26 +826,16 @@ class ACPExecClient:
|
||||
)
|
||||
try:
|
||||
prompt_response = PromptResponse.model_validate(result)
|
||||
packet_logger.log_acp_event_yielded(
|
||||
"prompt_response", prompt_response
|
||||
)
|
||||
events_yielded += 1
|
||||
yield prompt_response
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"type": "prompt_response", "error": str(e)},
|
||||
)
|
||||
logger.error(f"[ACP] PromptResponse validation failed: {e}")
|
||||
|
||||
# Log completion summary
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-COMPLETE-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -526,25 +844,29 @@ class ACPExecClient:
|
||||
params_data = message_data.get("params", {})
|
||||
update = params_data.get("update", {})
|
||||
|
||||
# Log the notification
|
||||
packet_logger.log_jsonrpc_notification(
|
||||
"session/update",
|
||||
{"update_type": update.get("sessionUpdate")},
|
||||
context="k8s",
|
||||
)
|
||||
|
||||
prompt_complete = False
|
||||
for event in self._process_session_update(update):
|
||||
events_yielded += 1
|
||||
# Log each yielded event
|
||||
event_type = self._get_event_type_name(event)
|
||||
packet_logger.log_acp_event_yielded(event_type, event)
|
||||
yield event
|
||||
if isinstance(event, PromptResponse):
|
||||
prompt_complete = True
|
||||
break
|
||||
|
||||
if prompt_complete:
|
||||
completion_reason = "prompt_response_via_notification"
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
# Handle requests from agent - send error response
|
||||
elif "method" in message_data and "id" in message_data:
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNSUPPORTED-REQUEST-K8S",
|
||||
{"method": message_data["method"], "id": message_data["id"]},
|
||||
logger.debug(
|
||||
f"[ACP] Unsupported agent request: "
|
||||
f"method={message_data['method']}"
|
||||
)
|
||||
self._send_error_response(
|
||||
message_data["id"],
|
||||
@@ -552,113 +874,50 @@ class ACPExecClient:
|
||||
f"Method not supported: {message_data['method']}",
|
||||
)
|
||||
|
||||
def _get_event_type_name(self, event: ACPEvent) -> str:
|
||||
"""Get the type name for an ACP event."""
|
||||
if isinstance(event, AgentMessageChunk):
|
||||
return "agent_message_chunk"
|
||||
elif isinstance(event, AgentThoughtChunk):
|
||||
return "agent_thought_chunk"
|
||||
elif isinstance(event, ToolCallStart):
|
||||
return "tool_call_start"
|
||||
elif isinstance(event, ToolCallProgress):
|
||||
return "tool_call_progress"
|
||||
elif isinstance(event, AgentPlanUpdate):
|
||||
return "agent_plan_update"
|
||||
elif isinstance(event, CurrentModeUpdate):
|
||||
return "current_mode_update"
|
||||
elif isinstance(event, PromptResponse):
|
||||
return "prompt_response"
|
||||
elif isinstance(event, Error):
|
||||
return "error"
|
||||
elif isinstance(event, SSEKeepalive):
|
||||
return "sse_keepalive"
|
||||
return "unknown"
|
||||
else:
|
||||
# Elevate to INFO — if the JSON-RPC response is arriving
|
||||
# but failing the is_response check, this will reveal it.
|
||||
logger.info(
|
||||
f"[ACP] Unhandled message: "
|
||||
f"id={message_data.get('id')} "
|
||||
f"type(id)={type(message_data.get('id')).__name__} "
|
||||
f"method={message_data.get('method')} "
|
||||
f"keys={list(message_data.keys())} "
|
||||
f"request_id={request_id} "
|
||||
f"has_result={'result' in message_data} "
|
||||
f"has_error={'error' in message_data}"
|
||||
)
|
||||
|
||||
def _process_session_update(
|
||||
self, update: dict[str, Any]
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Process a session/update notification and yield typed ACP schema objects."""
|
||||
update_type = update.get("sessionUpdate")
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
if update_type == "agent_message_chunk":
|
||||
# Map update types to their ACP schema classes
|
||||
type_map: dict[str, type] = {
|
||||
"agent_message_chunk": AgentMessageChunk,
|
||||
"agent_thought_chunk": AgentThoughtChunk,
|
||||
"tool_call": ToolCallStart,
|
||||
"tool_call_update": ToolCallProgress,
|
||||
"plan": AgentPlanUpdate,
|
||||
"current_mode_update": CurrentModeUpdate,
|
||||
"prompt_response": PromptResponse,
|
||||
}
|
||||
|
||||
model_class = type_map.get(update_type) # type: ignore[arg-type]
|
||||
if model_class is not None:
|
||||
try:
|
||||
yield AgentMessageChunk.model_validate(update)
|
||||
yield model_class.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "agent_thought_chunk":
|
||||
try:
|
||||
yield AgentThoughtChunk.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "user_message_chunk":
|
||||
# Echo of user message - skip but log
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "user_message_chunk"}
|
||||
)
|
||||
|
||||
elif update_type == "tool_call":
|
||||
try:
|
||||
yield ToolCallStart.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "tool_call_update":
|
||||
try:
|
||||
yield ToolCallProgress.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "plan":
|
||||
try:
|
||||
yield AgentPlanUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "current_mode_update":
|
||||
try:
|
||||
yield CurrentModeUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "available_commands_update":
|
||||
# Skip command updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "available_commands_update"}
|
||||
)
|
||||
|
||||
elif update_type == "session_info_update":
|
||||
# Skip session info updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "session_info_update"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Unknown update types are logged
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNKNOWN-UPDATE-TYPE-K8S",
|
||||
{"update_type": update_type, "update": update},
|
||||
)
|
||||
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
|
||||
elif update_type not in (
|
||||
"user_message_chunk",
|
||||
"available_commands_update",
|
||||
"session_info_update",
|
||||
"usage_update",
|
||||
):
|
||||
logger.debug(f"[ACP] Unknown update type: {update_type}")
|
||||
|
||||
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
|
||||
"""Send an error response to an agent request."""
|
||||
@@ -673,15 +932,24 @@ class ACPExecClient:
|
||||
|
||||
self._ws_client.write_stdin(json.dumps(response) + "\n")
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel the current operation."""
|
||||
if self._state.current_session is None:
|
||||
return
|
||||
def cancel(self, session_id: str | None = None) -> None:
|
||||
"""Cancel the current operation on a session.
|
||||
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": self._state.current_session.session_id},
|
||||
)
|
||||
Args:
|
||||
session_id: The ACP session ID to cancel. If None, cancels all sessions.
|
||||
"""
|
||||
if session_id:
|
||||
if session_id in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": session_id},
|
||||
)
|
||||
else:
|
||||
for sid in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": sid},
|
||||
)
|
||||
|
||||
def health_check(self, timeout: float = 5.0) -> bool: # noqa: ARG002
|
||||
"""Check if we can exec into the pod."""
|
||||
@@ -708,11 +976,9 @@ class ACPExecClient:
|
||||
return self._ws_client is not None and self._ws_client.is_open()
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
"""Get the current session ID, if any."""
|
||||
if self._state.current_session:
|
||||
return self._state.current_session.session_id
|
||||
return None
|
||||
def session_ids(self) -> list[str]:
|
||||
"""Get all tracked session IDs."""
|
||||
return list(self._state.sessions.keys())
|
||||
|
||||
def __enter__(self) -> "ACPExecClient":
|
||||
"""Context manager entry."""
|
||||
|
||||
@@ -50,6 +50,7 @@ from pathlib import Path
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from acp.schema import PromptResponse
|
||||
from kubernetes import client # type: ignore
|
||||
from kubernetes import config
|
||||
from kubernetes.client.rest import ApiException # type: ignore
|
||||
@@ -97,6 +98,10 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# API server pod hostname — used to identify which replica is handling a request.
|
||||
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
|
||||
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
|
||||
|
||||
# Constants for pod configuration
|
||||
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
|
||||
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
|
||||
@@ -348,6 +353,14 @@ class KubernetesSandboxManager(SandboxManager):
|
||||
self._service_account = SANDBOX_SERVICE_ACCOUNT_NAME
|
||||
self._file_sync_service_account = SANDBOX_FILE_SYNC_SERVICE_ACCOUNT
|
||||
|
||||
# One long-lived ACP client per sandbox (Zed-style architecture).
|
||||
# Multiple craft sessions share one `opencode acp` process per sandbox.
|
||||
self._acp_clients: dict[UUID, ACPExecClient] = {}
|
||||
|
||||
# Maps (sandbox_id, craft_session_id) → ACP session ID.
|
||||
# Each craft session has its own ACP session on the shared client.
|
||||
self._acp_session_ids: dict[tuple[UUID, UUID], str] = {}
|
||||
|
||||
# Load AGENTS.md template path
|
||||
build_dir = Path(__file__).parent.parent.parent # /onyx/server/features/build/
|
||||
self._agent_instructions_template_path = build_dir / "AGENTS.template.md"
|
||||
@@ -532,7 +545,7 @@ done
|
||||
],
|
||||
resources=client.V1ResourceRequirements(
|
||||
requests={"cpu": "1000m", "memory": "2Gi"},
|
||||
limits={"cpu": "4000m", "memory": "8Gi"},
|
||||
limits={"cpu": "2000m", "memory": "10Gi"},
|
||||
),
|
||||
# TODO: Re-enable probes when sandbox container runs actual services.
|
||||
# Note: Next.js ports are now per-session (dynamic), so container-level
|
||||
@@ -1156,11 +1169,28 @@ done
|
||||
def terminate(self, sandbox_id: UUID) -> None:
|
||||
"""Terminate a sandbox and clean up Kubernetes resources.
|
||||
|
||||
Deletes the Service and Pod for the sandbox.
|
||||
Stops the shared ACP client and removes all session mappings for this
|
||||
sandbox, then deletes the Service and Pod.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
"""
|
||||
# Stop the shared ACP client for this sandbox
|
||||
acp_client = self._acp_clients.pop(sandbox_id, None)
|
||||
if acp_client:
|
||||
try:
|
||||
acp_client.stop()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] Failed to stop ACP client for "
|
||||
f"sandbox {sandbox_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove all session mappings for this sandbox
|
||||
keys_to_remove = [key for key in self._acp_session_ids if key[0] == sandbox_id]
|
||||
for key in keys_to_remove:
|
||||
del self._acp_session_ids[key]
|
||||
|
||||
# Clean up Kubernetes resources (needs string for pod/service names)
|
||||
self._cleanup_kubernetes_resources(str(sandbox_id))
|
||||
|
||||
@@ -1395,7 +1425,8 @@ echo "Session workspace setup complete"
|
||||
) -> None:
|
||||
"""Clean up a session workspace (on session delete).
|
||||
|
||||
Executes kubectl exec to remove the session directory.
|
||||
Removes the ACP session mapping and executes kubectl exec to remove
|
||||
the session directory. The shared ACP client persists for other sessions.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1403,6 +1434,15 @@ echo "Session workspace setup complete"
|
||||
nextjs_port: Optional port where Next.js server is running (unused in K8s,
|
||||
we use PID file instead)
|
||||
"""
|
||||
# Remove the ACP session mapping (shared client persists)
|
||||
session_key = (sandbox_id, session_id)
|
||||
acp_session_id = self._acp_session_ids.pop(session_key, None)
|
||||
if acp_session_id:
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Removed ACP session mapping: "
|
||||
f"session={session_id} acp_session={acp_session_id}"
|
||||
)
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
@@ -1807,6 +1847,94 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
return exec_client.health_check(timeout=timeout)
|
||||
|
||||
def _get_or_create_acp_client(self, sandbox_id: UUID) -> ACPExecClient:
|
||||
"""Get the shared ACP client for a sandbox, creating one if needed.
|
||||
|
||||
One long-lived `opencode acp` process per sandbox (Zed-style).
|
||||
If the existing client's WebSocket has died, replaces it with a new one.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
|
||||
Returns:
|
||||
A running ACPExecClient for this sandbox
|
||||
"""
|
||||
acp_client = self._acp_clients.get(sandbox_id)
|
||||
|
||||
if acp_client is not None and acp_client.is_running:
|
||||
return acp_client
|
||||
|
||||
# Client is dead or doesn't exist — clean up stale one
|
||||
if acp_client is not None:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] Stale ACP client for sandbox {sandbox_id}, "
|
||||
f"replacing"
|
||||
)
|
||||
try:
|
||||
acp_client.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear session mappings — they're invalid on a new process
|
||||
keys_to_remove = [
|
||||
key for key in self._acp_session_ids if key[0] == sandbox_id
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._acp_session_ids[key]
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
new_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
new_client.start(cwd="/workspace")
|
||||
self._acp_clients[sandbox_id] = new_client
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Created shared ACP client: "
|
||||
f"sandbox={sandbox_id} pod={pod_name} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
return new_client
|
||||
|
||||
def _get_or_create_acp_session(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
acp_client: ACPExecClient,
|
||||
) -> str:
|
||||
"""Get the ACP session ID for a craft session, creating one if needed.
|
||||
|
||||
Uses the session mapping cache first, then falls back to
|
||||
`get_or_create_session()` which handles resume from opencode's
|
||||
persisted storage (multi-replica support).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The craft session ID
|
||||
acp_client: The shared ACP client for this sandbox
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
session_key = (sandbox_id, session_id)
|
||||
acp_session_id = self._acp_session_ids.get(session_key)
|
||||
|
||||
if acp_session_id and acp_session_id in acp_client.session_ids:
|
||||
return acp_session_id
|
||||
|
||||
# Session not tracked or was lost — get or create it
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
acp_session_id = acp_client.get_or_create_session(cwd=session_path)
|
||||
self._acp_session_ids[session_key] = acp_session_id
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Session mapped: "
|
||||
f"craft_session={session_id} acp_session={acp_session_id}"
|
||||
)
|
||||
return acp_session_id
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
@@ -1815,8 +1943,9 @@ echo "Session config regeneration complete"
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message to the CLI agent and stream ACP events.
|
||||
|
||||
Runs `opencode acp` via kubectl exec in the sandbox pod.
|
||||
The agent runs in the session-specific workspace.
|
||||
Uses a shared ACP client per sandbox (one `opencode acp` process).
|
||||
Each craft session has its own ACP session ID on that shared process.
|
||||
Switching between sessions is client-side — just use the right sessionId.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1827,37 +1956,46 @@ echo "Session config regeneration complete"
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
packet_logger = get_packet_logger()
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
# Get or create the shared ACP client for this sandbox
|
||||
acp_client = self._get_or_create_acp_client(sandbox_id)
|
||||
|
||||
# Get or create the ACP session for this craft session
|
||||
acp_session_id = self._get_or_create_acp_session(
|
||||
sandbox_id, session_id, acp_client
|
||||
)
|
||||
|
||||
exec_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Sending message: "
|
||||
f"session={session_id} acp_session={acp_session_id} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
got_prompt_response = False
|
||||
try:
|
||||
exec_client.start(cwd=session_path)
|
||||
for event in exec_client.send_message(message):
|
||||
for event in acp_client.send_message(message, session_id=acp_session_id):
|
||||
events_count += 1
|
||||
if isinstance(event, PromptResponse):
|
||||
got_prompt_response = True
|
||||
yield event
|
||||
|
||||
# Log successful completion
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] send_message completed: "
|
||||
f"session={session_id} events={events_count} "
|
||||
f"got_prompt_response={got_prompt_response}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id, success=True, events_count=events_count
|
||||
)
|
||||
except GeneratorExit:
|
||||
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
|
||||
# This is the most common failure mode for SSE streaming
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
|
||||
f"events={events_count}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1866,7 +2004,10 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log failure from normal exceptions
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] Exception: session={session_id} "
|
||||
f"events={events_count} error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1875,19 +2016,16 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except BaseException as e:
|
||||
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
|
||||
exception_type = type(e).__name__
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} " f"error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
exec_client.stop()
|
||||
# Log client stop
|
||||
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
|
||||
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
|
||||
@@ -1 +1,10 @@
|
||||
"""Celery tasks for sandbox management."""
|
||||
|
||||
from onyx.server.features.build.sandbox.tasks.tasks import (
|
||||
cleanup_idle_sandboxes_task,
|
||||
) # noqa: F401
|
||||
from onyx.server.features.build.sandbox.tasks.tasks import (
|
||||
sync_sandbox_files,
|
||||
) # noqa: F401
|
||||
|
||||
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]
|
||||
|
||||
@@ -11,6 +11,8 @@ from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import delete_search_settings
|
||||
@@ -118,7 +120,9 @@ def set_new_search_settings(
|
||||
# # Ensure Vespa has the new index immediately
|
||||
# get_multipass_config(search_settings)
|
||||
# get_multipass_config(new_search_settings)
|
||||
# document_index = get_default_document_index(search_settings, new_search_settings)
|
||||
# document_index = get_default_document_index(
|
||||
# search_settings, new_search_settings, db_session
|
||||
# )
|
||||
|
||||
# document_index.ensure_indices_exist(
|
||||
# primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
@@ -252,6 +256,13 @@ def update_saved_search_settings(
|
||||
search_settings=search_settings, db_session=db_session
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updated current search settings to {search_settings.model_dump_json()}"
|
||||
)
|
||||
|
||||
# Re-sync default to match PRESENT search settings
|
||||
_sync_default_contextual_model(db_session)
|
||||
|
||||
|
||||
@router.get("/unstructured-api-key-set")
|
||||
def unstructured_api_key_set(
|
||||
@@ -309,3 +320,23 @@ def _validate_contextual_rag_model(
|
||||
return f"Model {model_name} not found in provider {provider_name}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _sync_default_contextual_model(db_session: Session) -> None:
|
||||
"""Syncs the default CONTEXTUAL_RAG flow to match the PRESENT search settings."""
|
||||
primary = get_current_search_settings(db_session)
|
||||
|
||||
try:
|
||||
update_default_contextual_model(
|
||||
db_session=db_session,
|
||||
enable_contextual_rag=primary.enable_contextual_rag,
|
||||
contextual_rag_llm_provider=primary.contextual_rag_llm_provider,
|
||||
contextual_rag_llm_name=primary.contextual_rag_llm_name,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error syncing default contextual model, defaulting to no contextual model: {e}"
|
||||
)
|
||||
update_no_default_contextual_rag_provider(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -40,8 +39,6 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
from onyx.db.chat import add_chats_to_session_from_slack_thread
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import delete_all_chat_sessions_for_user
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
|
||||
@@ -49,7 +46,6 @@ from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
@@ -71,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
@@ -86,10 +81,7 @@ from onyx.server.query_and_chat.models import ChatSessionGroup
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionSummary
|
||||
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
@@ -503,71 +495,8 @@ def delete_chat_session_by_id(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# WARNING: this endpoint is deprecated and will be removed soon. Use the new send-chat-message endpoint instead.
|
||||
@router.post("/send-message")
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
_rate_limit_check: None = Depends(check_token_rate_limits),
|
||||
_api_key_usage_check: None = Depends(check_api_key_usage),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
This endpoint is both used for all the following purposes:
|
||||
- Sending a new message in the session
|
||||
- Regenerating a message in the session (just send the same one again)
|
||||
- Editing a message (similar to regenerating but sending a different message)
|
||||
- Kicking off a seeded chat session (set `use_existing_user_message`)
|
||||
|
||||
Assumes that previous messages have been set as the latest to minimize overhead.
|
||||
|
||||
Args:
|
||||
chat_message_req (CreateChatMessageRequest): Details about the new chat message.
|
||||
request (Request): The current HTTP request context.
|
||||
user (User): The current user, obtained via dependency injection.
|
||||
_ (None): Rate limit check is run if user/group/global rate limits are enabled.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: Streams the response to the new chat message.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
logger.debug(f"Received new chat message: {chat_message_req.message}")
|
||||
|
||||
if not chat_message_req.message and not chat_message_req.use_existing_user_message:
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=tenant_id if user.is_anonymous else user.email,
|
||||
event=MilestoneRecordType.RAN_QUERY,
|
||||
)
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in stream_chat_message_objects(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat message streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
finally:
|
||||
logger.debug("Stream generator finished")
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# NOTE: This endpoint is extremely central to the application, any changes to it should be reviewed and approved by an experienced
|
||||
# team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
@router.post(
|
||||
"/send-chat-message",
|
||||
response_model=ChatFullResponse,
|
||||
@@ -815,77 +744,6 @@ def get_available_context_tokens_for_session(
|
||||
"""Endpoints for chat seeding"""
|
||||
|
||||
|
||||
class ChatSeedRequest(BaseModel):
|
||||
# standard chat session stuff
|
||||
persona_id: int
|
||||
|
||||
# overrides / seeding
|
||||
llm_override: LLMOverride | None = None
|
||||
prompt_override: PromptOverride | None = None
|
||||
description: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
# TODO: support this
|
||||
# initial_message_retrieval_options: RetrievalDetails | None = None
|
||||
|
||||
|
||||
class ChatSeedResponse(BaseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
@router.post("/seed-chat-session", tags=PUBLIC_API_TAGS)
|
||||
def seed_chat(
|
||||
chat_seed_request: ChatSeedRequest,
|
||||
# NOTE: This endpoint is designed for programmatic access (API keys, external services)
|
||||
# rather than authenticated user sessions. The user parameter is used for access control
|
||||
# but the created chat session is "unassigned" (user_id=None) until a user visits the web UI.
|
||||
# This allows external systems to pre-seed chat sessions that users can then access.
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSeedResponse:
|
||||
|
||||
try:
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=chat_seed_request.description or "",
|
||||
user_id=None, # this chat session is "unassigned" until a user visits the web UI
|
||||
persona_id=chat_seed_request.persona_id,
|
||||
llm_override=chat_seed_request.llm_override,
|
||||
prompt_override=chat_seed_request.prompt_override,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
|
||||
|
||||
if chat_seed_request.message is not None:
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=new_chat_session.id, db_session=db_session
|
||||
)
|
||||
llm = get_llm_for_persona(
|
||||
persona=new_chat_session.persona,
|
||||
user=user,
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
token_count = len(tokenizer.encode(chat_seed_request.message))
|
||||
|
||||
create_new_chat_message(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message=root_message,
|
||||
message=chat_seed_request.message,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return ChatSeedResponse(
|
||||
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}&seeded=true"
|
||||
)
|
||||
|
||||
|
||||
class SeedChatFromSlackRequest(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import Tag
|
||||
@@ -20,7 +17,6 @@ from onyx.db.enums import ChatSessionSharedStatus
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
@@ -40,8 +36,9 @@ class MessageOrigin(str, Enum):
|
||||
UNSET = "unset"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
@@ -83,6 +80,8 @@ class ChatFeedbackRequest(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
@@ -141,115 +140,6 @@ class SendMessageRequest(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class OptionalSearchSetting(str, Enum):
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
# Determine whether to run search based on history and latest query
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
# If the Persona is configured to not run search (0 chunks), this is bypassed
|
||||
# If no Prompt is configured, the only search results are shown, this is bypassed
|
||||
run_search: OptionalSearchSetting = OptionalSearchSetting.AUTO
|
||||
# Is this a real-time/streaming call or a question where Onyx can take more time?
|
||||
# Used to determine reranking flow
|
||||
real_time: bool = True
|
||||
# The following have defaults in the Persona settings which can be overridden via
|
||||
# the query, if None, then use Persona settings
|
||||
filters: BaseFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
# if None, no offset / limit
|
||||
offset: int | None = None
|
||||
limit: int | None = None
|
||||
|
||||
# If this is set, only the highest matching chunk (or merged chunks) is returned
|
||||
dedupe_docs: bool = False
|
||||
|
||||
|
||||
class CreateChatMessageRequest(ChunkContext):
|
||||
"""Before creating messages, be sure to create a chat_session and get an id"""
|
||||
|
||||
chat_session_id: UUID
|
||||
# This is the primary-key (unique identifier) for the previous message of the tree
|
||||
parent_message_id: int | None
|
||||
|
||||
# New message contents
|
||||
message: str
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor] = []
|
||||
# Prompts are embedded in personas, so no separate prompt_id needed
|
||||
# If search_doc_ids provided, it should use those docs explicitly
|
||||
search_doc_ids: list[int] | None
|
||||
retrieval_options: RetrievalDetails | None
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# enables additional handling to ensure that we regenerate with a given user message ID
|
||||
regenerate: bool | None = None
|
||||
|
||||
# allows the caller to override the Persona / Prompt
|
||||
# these do not persist in the chat thread details
|
||||
llm_override: LLMOverride | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
prompt_override: PromptOverride | None = None
|
||||
|
||||
# Allows the caller to override the temperature for the chat session
|
||||
# this does persist in the chat thread details
|
||||
temperature_override: float | None = None
|
||||
|
||||
# allow user to specify an alternate assistant
|
||||
alternate_assistant_id: int | None = None
|
||||
|
||||
# This takes the priority over the prompt_override
|
||||
# This won't be a type that's passed in directly from the API
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
# used for "OpenAI Assistants API"
|
||||
existing_assistant_message_id: int | None = None
|
||||
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
# List of allowed tool IDs to restrict tool usage. If not provided, all tools available to the persona will be used.
|
||||
allowed_tool_ids: list[int] | None = None
|
||||
|
||||
# List of tool IDs we MUST use.
|
||||
# TODO: make this a single one since unclear how to force this for multiple at a time.
|
||||
forced_tool_ids: list[int] | None = None
|
||||
|
||||
deep_research: bool = False
|
||||
|
||||
# When True (default), enables citation generation with markers and CitationInfo packets
|
||||
# When False, disables citations: removes markers like [1], [2] and skips CitationInfo packets
|
||||
include_citations: bool = True
|
||||
|
||||
# Origin of the message for telemetry tracking
|
||||
origin: MessageOrigin = MessageOrigin.UNKNOWN
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
|
||||
if self.search_doc_ids is None and self.retrieval_options is None:
|
||||
raise ValueError(
|
||||
"Either search_doc_ids or retrieval_options must be provided, but not both or neither."
|
||||
)
|
||||
return self
|
||||
|
||||
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
data = super().model_dump(*args, **kwargs)
|
||||
data["chat_session_id"] = str(data["chat_session_id"])
|
||||
return data
|
||||
|
||||
|
||||
class ChatMessageIdentifier(BaseModel):
|
||||
message_id: int
|
||||
|
||||
@@ -365,13 +255,3 @@ class ChatSearchResponse(BaseModel):
|
||||
groups: list[ChatSessionGroup]
|
||||
has_more: bool
|
||||
next_page: int | None = None
|
||||
|
||||
|
||||
class ChatSearchRequest(BaseModel):
|
||||
query: str | None = None
|
||||
page: int = 1
|
||||
page_size: int = 10
|
||||
|
||||
|
||||
class CreateChatResponse(BaseModel):
|
||||
chat_session_id: str
|
||||
|
||||
@@ -343,7 +343,13 @@ def run_tool_calls(
|
||||
raise ValueError("No user message found in message history")
|
||||
|
||||
search_memory_context = (
|
||||
user_memory_context if inject_memories_in_prompt else None
|
||||
user_memory_context
|
||||
if inject_memories_in_prompt
|
||||
else (
|
||||
user_memory_context.without_memories()
|
||||
if user_memory_context
|
||||
else None
|
||||
)
|
||||
)
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
starting_citation_num=starting_citation_num,
|
||||
|
||||
@@ -17,11 +17,12 @@ disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"^generated/.*",
|
||||
"^\\.venv/",
|
||||
"^onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/",
|
||||
"^onyx/server/features/build/sandbox/kubernetes/docker/templates/venv/",
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
|
||||
@@ -23,7 +23,7 @@ def create_new_chat_session(onyx_url: str, api_key: str | None) -> int:
|
||||
|
||||
|
||||
def process_question(onyx_url: str, question: str, api_key: str | None) -> None:
|
||||
message_endpoint = onyx_url + "/api/chat/send-message"
|
||||
message_endpoint = onyx_url + "/api/chat/send-chat-message"
|
||||
|
||||
chat_session_id = create_new_chat_session(onyx_url, api_key)
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ class ChatLoadTester:
|
||||
token_count = 0
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/send-message",
|
||||
f"{self.base_url}/chat/send-chat-message",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"chat_session_id": chat_session_id,
|
||||
|
||||
@@ -259,6 +259,145 @@ def test_airtable_connector_basic(
|
||||
compare_documents(doc_batch, expected_docs)
|
||||
|
||||
|
||||
def test_airtable_connector_url(
|
||||
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
|
||||
airtable_config: AirtableConfig,
|
||||
) -> None:
|
||||
"""Test that passing an Airtable URL produces the same results as base_id + table_id."""
|
||||
if not airtable_config.table_identifier.startswith("tbl"):
|
||||
pytest.skip("URL test requires table ID, not table name")
|
||||
|
||||
url = f"https://airtable.com/{airtable_config.base_id}/{airtable_config.table_identifier}/{BASE_VIEW_ID}"
|
||||
connector = AirtableConnector(
|
||||
airtable_url=url,
|
||||
treat_all_non_attachment_fields_as_metadata=False,
|
||||
)
|
||||
connector.load_credentials({"airtable_access_token": airtable_config.access_token})
|
||||
|
||||
doc_batch_generator = connector.load_from_state()
|
||||
doc_batch = [
|
||||
doc for doc in next(doc_batch_generator) if not isinstance(doc, HierarchyNode)
|
||||
]
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 2
|
||||
|
||||
expected_docs = [
|
||||
create_test_document(
|
||||
id="rec8BnxDLyWeegOuO",
|
||||
title="Slow Internet",
|
||||
description="The internet connection is very slow.",
|
||||
priority="Medium",
|
||||
status="In Progress",
|
||||
ticket_id="2",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
all_fields_as_metadata=False,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
create_test_document(
|
||||
id="reccSlIA4pZEFxPBg",
|
||||
title="Printer Issue",
|
||||
description="The office printer is not working.",
|
||||
priority="High",
|
||||
status="Open",
|
||||
ticket_id="1",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
attachments=[
|
||||
(
|
||||
"Test.pdf:\ntesting!!!",
|
||||
f"https://airtable.com/{airtable_config.base_id}/{airtable_config.table_identifier}/{BASE_VIEW_ID}/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=False,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
]
|
||||
|
||||
compare_documents(doc_batch, expected_docs)
|
||||
|
||||
|
||||
def test_airtable_connector_index_all(
|
||||
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
|
||||
airtable_config: AirtableConfig,
|
||||
) -> None:
|
||||
"""Test index_all mode discovers all bases/tables and returns documents.
|
||||
|
||||
The test token has access to one base ("Onyx") with three tables:
|
||||
- Tickets: 3 records, 2 with content (1 empty record is skipped)
|
||||
- Support Categories: 4 records, all with Category Name field
|
||||
- Table 3: 3 records, 1 with content (2 empty records are skipped)
|
||||
Total expected: 7 documents
|
||||
"""
|
||||
connector = AirtableConnector()
|
||||
connector.load_credentials({"airtable_access_token": airtable_config.access_token})
|
||||
|
||||
all_docs: list[Document] = []
|
||||
for batch in connector.load_from_state():
|
||||
for item in batch:
|
||||
if isinstance(item, Document):
|
||||
all_docs.append(item)
|
||||
|
||||
# 2 from Tickets + 4 from Support Categories + 1 from Table 3 = 7
|
||||
assert len(all_docs) == 7
|
||||
|
||||
docs_by_id = {d.id: d for d in all_docs}
|
||||
|
||||
# Verify all expected document IDs are present
|
||||
expected_ids = {
|
||||
# Tickets
|
||||
"airtable__rec8BnxDLyWeegOuO",
|
||||
"airtable__reccSlIA4pZEFxPBg",
|
||||
# Support Categories
|
||||
"airtable__rec5SgUDcHXcBc8kS",
|
||||
"airtable__recD3DQHc0BQkDaqX",
|
||||
"airtable__recPHdnWu1Q9ZxyTg",
|
||||
"airtable__recWbIElUDz9HjgMd",
|
||||
# Table 3
|
||||
"airtable__recNalBz02QU1LhbM",
|
||||
}
|
||||
assert docs_by_id.keys() == expected_ids
|
||||
|
||||
# In index_all mode, semantic identifiers include "Base Name > Table Name: Primary Field"
|
||||
assert (
|
||||
docs_by_id["airtable__rec8BnxDLyWeegOuO"].semantic_identifier
|
||||
== "Onyx > Tickets: Slow Internet"
|
||||
)
|
||||
assert (
|
||||
docs_by_id["airtable__rec5SgUDcHXcBc8kS"].semantic_identifier
|
||||
== "Onyx > Support Categories: Software Development"
|
||||
)
|
||||
assert (
|
||||
docs_by_id["airtable__recNalBz02QU1LhbM"].semantic_identifier
|
||||
== "Onyx > Table 3: A"
|
||||
)
|
||||
|
||||
# Verify hierarchy metadata on a Tickets doc
|
||||
tickets_doc = docs_by_id["airtable__rec8BnxDLyWeegOuO"]
|
||||
assert tickets_doc.doc_metadata is not None
|
||||
hierarchy = tickets_doc.doc_metadata["hierarchy"]
|
||||
assert hierarchy["source_path"] == ["Onyx", "Tickets"]
|
||||
assert hierarchy["base_id"] == airtable_config.base_id
|
||||
assert hierarchy["base_name"] == "Onyx"
|
||||
assert hierarchy["table_name"] == "Tickets"
|
||||
|
||||
# Verify hierarchy on a Support Categories doc
|
||||
cat_doc = docs_by_id["airtable__rec5SgUDcHXcBc8kS"]
|
||||
assert cat_doc.doc_metadata is not None
|
||||
assert cat_doc.doc_metadata["hierarchy"]["source_path"] == [
|
||||
"Onyx",
|
||||
"Support Categories",
|
||||
]
|
||||
|
||||
|
||||
def test_airtable_connector_all_metadata(
|
||||
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
|
||||
airtable_config: AirtableConfig,
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import cast
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
|
||||
|
||||
@@ -6,9 +6,8 @@ from uuid import uuid4
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
@@ -18,8 +17,8 @@ from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -70,17 +69,13 @@ def test_answer_with_only_anthropic_provider(
|
||||
persona_id=0,
|
||||
)
|
||||
|
||||
chat_request = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=None,
|
||||
chat_request = SendMessageRequest(
|
||||
message="hello",
|
||||
file_descriptors=[],
|
||||
search_doc_ids=None,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
chat_session_id=chat_session.id,
|
||||
)
|
||||
|
||||
response_stream: list[AnswerStreamPart] = []
|
||||
for packet in stream_chat_message_objects(
|
||||
for packet in handle_stream_message_objects(
|
||||
new_msg_req=chat_request,
|
||||
user=test_user,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -4,14 +4,13 @@ from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from tests.external_dependency_unit.answer.conftest import ensure_default_llm_provider
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
@@ -42,18 +41,12 @@ def test_stream_chat_current_date_response(
|
||||
persona_id=default_persona.id,
|
||||
)
|
||||
|
||||
chat_request = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=None,
|
||||
chat_request = SendMessageRequest(
|
||||
message="Please respond only with the current date in the format 'Weekday Month DD, YYYY'.",
|
||||
file_descriptors=[],
|
||||
prompt_override=None,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
query_override=None,
|
||||
chat_session_id=chat_session.id,
|
||||
)
|
||||
|
||||
gen = stream_chat_message_objects(
|
||||
gen = handle_stream_message_objects(
|
||||
new_msg_req=chat_request,
|
||||
user=test_user,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -7,8 +7,8 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
|
||||
@@ -6,15 +6,14 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from tests.external_dependency_unit.answer.conftest import ensure_default_llm_provider
|
||||
@@ -100,18 +99,12 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
persona_id=test_persona.id,
|
||||
)
|
||||
# Create the chat message request with a query that attempts to force web search
|
||||
chat_request = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=None,
|
||||
chat_request = SendMessageRequest(
|
||||
message="run a web search for 'Onyx'",
|
||||
file_descriptors=[],
|
||||
prompt_override=None,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
query_override=None,
|
||||
chat_session_id=chat_session.id,
|
||||
)
|
||||
# Call stream_chat_message_objects
|
||||
response_generator = stream_chat_message_objects(
|
||||
# Call handle_stream_message_objects
|
||||
response_generator = handle_stream_message_objects(
|
||||
new_msg_req=chat_request,
|
||||
user=test_user,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests verify that:
|
||||
1. USER_REMINDER messages are wrapped with <system-reminder> tags
|
||||
2. The wrapped messages are converted to UserMessage type for the LLM
|
||||
3. The tags are properly applied around the message content
|
||||
4. CODE_BLOCK_MARKDOWN is prepended to system messages for models that need it
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -14,7 +15,9 @@ from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
|
||||
|
||||
@@ -175,3 +178,161 @@ class TestUserReminderMessageType:
|
||||
assert SYSTEM_REMINDER_TAG_OPEN not in msg.content
|
||||
assert SYSTEM_REMINDER_TAG_CLOSE not in msg.content
|
||||
assert msg.content == "This is a normal user message."
|
||||
|
||||
|
||||
def _create_llm_config(model_name: str) -> LLMConfig:
|
||||
"""Create a LLMConfig with the specified model name."""
|
||||
return LLMConfig(
|
||||
model_provider="openai",
|
||||
model_name=model_name,
|
||||
temperature=0.7,
|
||||
api_key="test-key",
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
max_input_tokens=128000,
|
||||
)
|
||||
|
||||
|
||||
class TestCodeBlockMarkdownFormatting:
|
||||
"""Tests for CODE_BLOCK_MARKDOWN prefix handling in translate_history_to_llm_format.
|
||||
|
||||
OpenAI reasoning models (o1, o3, gpt-5) need a "Formatting re-enabled. " prefix
|
||||
in their system messages for correct markdown generation.
|
||||
"""
|
||||
|
||||
def test_o1_model_prepends_markdown_to_string(self) -> None:
|
||||
"""Test that o1 model prepends CODE_BLOCK_MARKDOWN to string system message."""
|
||||
llm_config = _create_llm_config("o1")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="You are a helpful assistant.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 1
|
||||
msg = result[0]
|
||||
assert isinstance(msg, SystemMessage)
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content == CODE_BLOCK_MARKDOWN + "You are a helpful assistant."
|
||||
|
||||
def test_o3_model_prepends_markdown(self) -> None:
|
||||
"""Test that o3 model prepends CODE_BLOCK_MARKDOWN to system message."""
|
||||
llm_config = _create_llm_config("o3-mini")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="System prompt here.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 1
|
||||
msg = result[0]
|
||||
assert isinstance(msg, SystemMessage)
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content.startswith(CODE_BLOCK_MARKDOWN)
|
||||
|
||||
def test_gpt5_model_prepends_markdown(self) -> None:
|
||||
"""Test that gpt-5 model prepends CODE_BLOCK_MARKDOWN to system message."""
|
||||
llm_config = _create_llm_config("gpt-5")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="System prompt here.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 1
|
||||
msg = result[0]
|
||||
assert isinstance(msg, SystemMessage)
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content.startswith(CODE_BLOCK_MARKDOWN)
|
||||
|
||||
def test_gpt4o_does_not_prepend(self) -> None:
|
||||
"""Test that gpt-4o model does NOT prepend CODE_BLOCK_MARKDOWN."""
|
||||
llm_config = _create_llm_config("gpt-4o")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="You are a helpful assistant.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 1
|
||||
msg = result[0]
|
||||
assert isinstance(msg, SystemMessage)
|
||||
assert isinstance(msg.content, str)
|
||||
# Should NOT have the prefix
|
||||
assert msg.content == "You are a helpful assistant."
|
||||
assert not msg.content.startswith(CODE_BLOCK_MARKDOWN)
|
||||
|
||||
def test_no_system_message_no_crash(self) -> None:
|
||||
"""Test that history without system message doesn't crash."""
|
||||
llm_config = _create_llm_config("o1")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="Hello!",
|
||||
token_count=5,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 1
|
||||
msg = result[0]
|
||||
assert isinstance(msg, UserMessage)
|
||||
assert msg.content == "Hello!"
|
||||
|
||||
def test_only_first_system_message_modified(self) -> None:
|
||||
"""Test that only the first system message gets the prefix."""
|
||||
llm_config = _create_llm_config("o1")
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="First system prompt.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message="Hello!",
|
||||
token_count=5,
|
||||
message_type=MessageType.USER,
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message="Second system prompt.",
|
||||
token_count=10,
|
||||
message_type=MessageType.SYSTEM,
|
||||
),
|
||||
]
|
||||
|
||||
raw_result = translate_history_to_llm_format(history, llm_config)
|
||||
result = _ensure_list(raw_result)
|
||||
|
||||
assert len(result) == 3
|
||||
# First system message should have prefix
|
||||
first_sys = result[0]
|
||||
assert isinstance(first_sys, SystemMessage)
|
||||
assert isinstance(first_sys.content, str)
|
||||
assert first_sys.content.startswith(CODE_BLOCK_MARKDOWN)
|
||||
# Second system message should NOT have prefix (only first one is modified)
|
||||
second_sys = result[2]
|
||||
assert isinstance(second_sys, SystemMessage)
|
||||
assert isinstance(second_sys.content, str)
|
||||
assert not second_sys.content.startswith(CODE_BLOCK_MARKDOWN)
|
||||
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -21,6 +20,7 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.query_and_chat.chat_backend import create_new_chat_session
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from tests.external_dependency_unit.answer.stream_test_assertions import (
|
||||
assert_answer_stream_part_correct,
|
||||
)
|
||||
|
||||
@@ -10,11 +10,12 @@ from sqlalchemy.orm import Session
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.llm import fetch_default_contextual_rag_model
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.search_settings import create_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.indexing.indexing_pipeline import IndexingPipelineResult
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
@@ -131,26 +132,37 @@ def baseline_search_settings(
|
||||
) -> None:
|
||||
"""Ensure a baseline PRESENT search settings row exists in the DB,
|
||||
which is required before set_new_search_settings can be called."""
|
||||
baseline = _make_saved_search_settings(enable_contextual_rag=False)
|
||||
create_search_settings(
|
||||
search_settings=_make_saved_search_settings(enable_contextual_rag=False),
|
||||
search_settings=baseline,
|
||||
db_session=db_session,
|
||||
status=IndexModelStatus.PRESENT,
|
||||
)
|
||||
# Sync default contextual model to match PRESENT (clears any leftover state)
|
||||
update_default_contextual_model(
|
||||
db_session=db_session,
|
||||
enable_contextual_rag=baseline.enable_contextual_rag,
|
||||
contextual_rag_llm_provider=baseline.contextual_rag_llm_provider,
|
||||
contextual_rag_llm_name=baseline.contextual_rag_llm_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""After creating search settings via set_new_search_settings with
|
||||
contextual RAG enabled, run_indexing_pipeline should call
|
||||
get_llm_for_contextual_rag with the LLM names from those settings."""
|
||||
"""After creating FUTURE settings and swapping to PRESENT,
|
||||
fetch_default_contextual_rag_model should match the PRESENT settings
|
||||
and run_indexing_pipeline should call get_llm_for_contextual_rag."""
|
||||
_create_llm_provider_and_model(
|
||||
db_session=db_session,
|
||||
provider_name=TEST_CONTEXTUAL_RAG_LLM_PROVIDER,
|
||||
@@ -163,6 +175,20 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# PRESENT still has contextual RAG disabled, so default should be None
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is None
|
||||
|
||||
# Swap FUTURE → PRESENT (with 0 cc-pairs, REINDEX swaps immediately)
|
||||
mock_get_all_doc_indices.return_value = []
|
||||
old_settings = check_and_perform_index_swap(db_session)
|
||||
assert old_settings is not None, "Swap should have occurred"
|
||||
|
||||
# Now PRESENT has contextual RAG enabled, default should match
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == TEST_CONTEXTUAL_RAG_LLM_NAME
|
||||
|
||||
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
|
||||
|
||||
mock_get_llm.assert_called_once_with(
|
||||
@@ -172,16 +198,21 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""After updating search settings via update_saved_search_settings,
|
||||
run_indexing_pipeline should use the updated LLM names."""
|
||||
"""After creating FUTURE settings, swapping to PRESENT, then updating
|
||||
via update_saved_search_settings, run_indexing_pipeline should use
|
||||
the updated LLM names."""
|
||||
_create_llm_provider_and_model(
|
||||
db_session=db_session,
|
||||
provider_name=TEST_CONTEXTUAL_RAG_LLM_PROVIDER,
|
||||
@@ -193,20 +224,28 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
model_name=UPDATED_CONTEXTUAL_RAG_LLM_NAME,
|
||||
)
|
||||
|
||||
# Create baseline PRESENT settings with contextual RAG already enabled
|
||||
create_search_settings(
|
||||
search_settings=_make_saved_search_settings(),
|
||||
# Create FUTURE settings with contextual RAG enabled
|
||||
set_new_search_settings(
|
||||
search_settings_new=_make_creation_request(),
|
||||
_=MagicMock(),
|
||||
db_session=db_session,
|
||||
status=IndexModelStatus.PRESENT,
|
||||
)
|
||||
|
||||
# Retire any FUTURE settings left over from other tests so the
|
||||
# pipeline uses the PRESENT (primary) settings we just created.
|
||||
secondary = get_secondary_search_settings(db_session)
|
||||
if secondary:
|
||||
update_search_settings_status(secondary, IndexModelStatus.PAST, db_session)
|
||||
# PRESENT still has contextual RAG disabled, so default should be None
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is None
|
||||
|
||||
# Update LLM names via the endpoint function
|
||||
# Swap FUTURE → PRESENT (with 0 cc-pairs, REINDEX swaps immediately)
|
||||
mock_get_all_doc_indices.return_value = []
|
||||
old_settings = check_and_perform_index_swap(db_session)
|
||||
assert old_settings is not None, "Swap should have occurred"
|
||||
|
||||
# Now PRESENT has contextual RAG enabled, default should match
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == TEST_CONTEXTUAL_RAG_LLM_NAME
|
||||
|
||||
# Update the PRESENT LLM names
|
||||
update_saved_search_settings(
|
||||
search_settings=_make_saved_search_settings(
|
||||
llm_name=UPDATED_CONTEXTUAL_RAG_LLM_NAME,
|
||||
@@ -216,6 +255,10 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == UPDATED_CONTEXTUAL_RAG_LLM_NAME
|
||||
|
||||
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
|
||||
|
||||
mock_get_llm.assert_called_once_with(
|
||||
@@ -231,6 +274,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -248,6 +292,10 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# PRESENT has contextual RAG disabled, so default should be None
|
||||
default_model = fetch_default_contextual_rag_model(db_session)
|
||||
assert default_model is None
|
||||
|
||||
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
|
||||
|
||||
mock_get_llm.assert_not_called()
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_create_chat_session_and_send_messages() -> None:
|
||||
# Send first message
|
||||
first_message = "Hello, this is a test message."
|
||||
send_message_response = requests.post(
|
||||
f"{base_url}/chat/send-message",
|
||||
f"{base_url}/chat/send-chat-message",
|
||||
json={
|
||||
"chat_session_id": chat_session_id,
|
||||
"message": first_message,
|
||||
@@ -43,7 +43,7 @@ def test_create_chat_session_and_send_messages() -> None:
|
||||
# Send second message
|
||||
second_message = "Can you provide more information?"
|
||||
send_message_response = requests.post(
|
||||
f"{base_url}/chat/send-message",
|
||||
f"{base_url}/chat/send-chat-message",
|
||||
json={
|
||||
"chat_session_id": chat_session_id,
|
||||
"message": second_message,
|
||||
|
||||
@@ -12,10 +12,9 @@ from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
@@ -104,37 +103,27 @@ class ChatSessionManager:
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
search_doc_ids: list[int] | None = None,
|
||||
retrieval_options: RetrievalDetails | None = None,
|
||||
query_override: str | None = None,
|
||||
regenerate: bool | None = None,
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
alternate_assistant_id: int | None = None,
|
||||
use_existing_user_message: bool = False,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
chat_session: DATestChatSession | None = None,
|
||||
mock_llm_response: str | None = None,
|
||||
deep_research: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
) -> StreamedResponse:
|
||||
chat_message_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
chat_message_req = SendMessageRequest(
|
||||
message=message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=(
|
||||
parent_message_id
|
||||
if parent_message_id is not None
|
||||
else AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
),
|
||||
file_descriptors=file_descriptors or [],
|
||||
search_doc_ids=search_doc_ids or [],
|
||||
retrieval_options=retrieval_options,
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
mock_llm_response=mock_llm_response,
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
forced_tool_id=forced_tool_ids[0] if forced_tool_ids else None,
|
||||
mock_llm_response=mock_llm_response,
|
||||
deep_research=deep_research,
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
@@ -145,8 +134,8 @@ class ChatSessionManager:
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message",
|
||||
json=chat_message_req.model_dump(),
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
@@ -182,17 +171,11 @@ class ChatSessionManager:
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
search_doc_ids: list[int] | None = None,
|
||||
query_override: str | None = None,
|
||||
regenerate: bool | None = None,
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
alternate_assistant_id: int | None = None,
|
||||
use_existing_user_message: bool = False,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
mock_llm_response: str | None = None,
|
||||
deep_research: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a message and simulate client disconnect before stream completes.
|
||||
@@ -204,33 +187,25 @@ class ChatSessionManager:
|
||||
chat_session_id: The chat session ID
|
||||
message: The message to send
|
||||
disconnect_after_packets: Disconnect after receiving this many packets.
|
||||
If None, disconnect_after_type must be specified.
|
||||
disconnect_after_type: Disconnect after receiving a packet of this type
|
||||
(e.g., "message_start", "search_tool_start"). If None,
|
||||
disconnect_after_packets must be specified.
|
||||
... (other standard message parameters)
|
||||
|
||||
Returns:
|
||||
StreamedResponse containing data received before disconnect,
|
||||
with is_disconnected=True flag set.
|
||||
None. Caller can verify server-side cleanup via get_chat_history etc.
|
||||
"""
|
||||
chat_message_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
chat_message_req = SendMessageRequest(
|
||||
message=message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=(
|
||||
parent_message_id
|
||||
if parent_message_id is not None
|
||||
else AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
),
|
||||
file_descriptors=file_descriptors or [],
|
||||
search_doc_ids=search_doc_ids or [],
|
||||
retrieval_options=RetrievalDetails(), # This will be deprecated soon anyway
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
mock_llm_response=mock_llm_response,
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
forced_tool_id=forced_tool_ids[0] if forced_tool_ids else None,
|
||||
mock_llm_response=mock_llm_response,
|
||||
deep_research=deep_research,
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
@@ -243,8 +218,8 @@ class ChatSessionManager:
|
||||
packets_received = 0
|
||||
|
||||
with requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message",
|
||||
json=chat_message_req.model_dump(),
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from onyx.configs import app_configs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import RetrievalDetails
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
@@ -172,7 +170,7 @@ def test_run_search_always_maps_to_forced_search_tool(admin_user: DATestUser) ->
|
||||
chat_session_id=chat_session.id,
|
||||
message="always run search",
|
||||
user_performing_action=admin_user,
|
||||
retrieval_options=RetrievalDetails(run_search=OptionalSearchSetting.ALWAYS),
|
||||
forced_tool_ids=[search_tool_id],
|
||||
mock_llm_response='{"name":"internal_search","arguments":{"queries":["gamma"]}}',
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class TestOnyxWebCrawler:
|
||||
content from public websites correctly.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_fetches_public_url_successfully(self, admin_user: DATestUser) -> None:
|
||||
"""Test that the crawler can fetch content from a public URL."""
|
||||
response = requests.post(
|
||||
@@ -40,6 +41,7 @@ class TestOnyxWebCrawler:
|
||||
assert "This domain is for use in" in content
|
||||
assert "documentation" in content or "illustrative" in content
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_fetches_multiple_urls(self, admin_user: DATestUser) -> None:
|
||||
"""Test that the crawler can fetch multiple URLs in one request."""
|
||||
response = requests.post(
|
||||
@@ -263,6 +265,7 @@ def _activate_exa_provider(admin_user: DATestUser) -> int:
|
||||
|
||||
|
||||
@pytestmark_exa
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_web_search_endpoints_with_exa(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
|
||||
152
backend/tests/unit/onyx/chat/test_chat_utils.py
Normal file
152
backend/tests/unit/onyx/chat/test_chat_utils.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Tests for chat_utils.py, specifically get_custom_agent_prompt."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
|
||||
|
||||
class TestGetCustomAgentPrompt:
|
||||
"""Tests for the get_custom_agent_prompt function."""
|
||||
|
||||
def _create_mock_persona(
|
||||
self,
|
||||
persona_id: int = 1,
|
||||
system_prompt: str | None = None,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock Persona with the specified attributes."""
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.system_prompt = system_prompt
|
||||
persona.replace_base_system_prompt = replace_base_system_prompt
|
||||
return persona
|
||||
|
||||
def _create_mock_chat_session(
|
||||
self,
|
||||
project: MagicMock | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the specified attributes."""
|
||||
chat_session = MagicMock()
|
||||
chat_session.project = project
|
||||
return chat_session
|
||||
|
||||
def _create_mock_project(
|
||||
self,
|
||||
instructions: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock UserProject with the specified attributes."""
|
||||
project = MagicMock()
|
||||
project.instructions = instructions
|
||||
return project
|
||||
|
||||
def test_default_persona_no_project(self) -> None:
|
||||
"""Test that default persona without a project returns None."""
|
||||
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
chat_session = self._create_mock_chat_session(project=None)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_default_persona_with_project_instructions(self) -> None:
|
||||
"""Test that default persona in a project returns project instructions."""
|
||||
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
project = self._create_mock_project(instructions="Do X and Y")
|
||||
chat_session = self._create_mock_chat_session(project=project)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result == "Do X and Y"
|
||||
|
||||
def test_default_persona_with_empty_project_instructions(self) -> None:
|
||||
"""Test that default persona in a project with empty instructions returns None."""
|
||||
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
project = self._create_mock_project(instructions="")
|
||||
chat_session = self._create_mock_chat_session(project=project)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_custom_persona_replace_base_prompt_true(self) -> None:
|
||||
"""Test that custom persona with replace_base_system_prompt=True returns None."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt="Custom system prompt",
|
||||
replace_base_system_prompt=True,
|
||||
)
|
||||
chat_session = self._create_mock_chat_session(project=None)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_custom_persona_with_system_prompt(self) -> None:
|
||||
"""Test that custom persona with system_prompt returns the system_prompt."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt="Custom system prompt",
|
||||
replace_base_system_prompt=False,
|
||||
)
|
||||
chat_session = self._create_mock_chat_session(project=None)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result == "Custom system prompt"
|
||||
|
||||
def test_custom_persona_empty_string_system_prompt(self) -> None:
|
||||
"""Test that custom persona with empty string system_prompt returns None."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt="",
|
||||
replace_base_system_prompt=False,
|
||||
)
|
||||
chat_session = self._create_mock_chat_session(project=None)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_custom_persona_none_system_prompt(self) -> None:
|
||||
"""Test that custom persona with None system_prompt returns None."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt=None,
|
||||
replace_base_system_prompt=False,
|
||||
)
|
||||
chat_session = self._create_mock_chat_session(project=None)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_custom_persona_in_project_uses_persona_prompt(self) -> None:
|
||||
"""Test that custom persona in a project uses persona's system_prompt, not project instructions."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt="Custom system prompt",
|
||||
replace_base_system_prompt=False,
|
||||
)
|
||||
project = self._create_mock_project(instructions="Project instructions")
|
||||
chat_session = self._create_mock_chat_session(project=project)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# Should use persona's system_prompt, NOT project instructions
|
||||
assert result == "Custom system prompt"
|
||||
|
||||
def test_custom_persona_replace_base_in_project(self) -> None:
|
||||
"""Test that custom persona with replace_base_system_prompt=True in a project still returns None."""
|
||||
persona = self._create_mock_persona(
|
||||
persona_id=1,
|
||||
system_prompt="Custom system prompt",
|
||||
replace_base_system_prompt=True,
|
||||
)
|
||||
project = self._create_mock_project(instructions="Project instructions")
|
||||
chat_session = self._create_mock_chat_session(project=project)
|
||||
|
||||
result = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# Should return None because replace_base_system_prompt=True
|
||||
assert result is None
|
||||
@@ -0,0 +1,618 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.airtable.airtable_connector import parse_airtable_url
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
def _make_field_schema(field_id: str, name: str, field_type: str) -> MagicMock:
|
||||
field = MagicMock()
|
||||
field.id = field_id
|
||||
field.name = name
|
||||
field.type = field_type
|
||||
return field
|
||||
|
||||
|
||||
def _make_table_schema(
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
primary_field_id: str,
|
||||
fields: list[MagicMock],
|
||||
) -> MagicMock:
|
||||
schema = MagicMock()
|
||||
schema.id = table_id
|
||||
schema.name = table_name
|
||||
schema.primary_field_id = primary_field_id
|
||||
schema.fields = fields
|
||||
schema.views = []
|
||||
return schema
|
||||
|
||||
|
||||
def _make_record(record_id: str, fields: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"id": record_id, "fields": fields}
|
||||
|
||||
|
||||
def _make_base_info(base_id: str, name: str) -> MagicMock:
|
||||
info = MagicMock()
|
||||
info.id = base_id
|
||||
info.name = name
|
||||
return info
|
||||
|
||||
|
||||
def _make_table_obj(table_id: str, name: str) -> MagicMock:
|
||||
obj = MagicMock()
|
||||
obj.id = table_id
|
||||
obj.name = name
|
||||
return obj
|
||||
|
||||
|
||||
def _setup_mock_api(
|
||||
bases: list[dict[str, Any]],
|
||||
) -> MagicMock:
|
||||
"""Set up a mock AirtableApi with bases, tables, records, and schemas.
|
||||
|
||||
Args:
|
||||
bases: List of dicts with keys: id, name, tables.
|
||||
Each table is a dict with: id, name, primary_field_id, fields, records.
|
||||
Each field is a dict with: id, name, type.
|
||||
Each record is a dict with: id, fields.
|
||||
"""
|
||||
mock_api = MagicMock()
|
||||
|
||||
base_infos = [_make_base_info(b["id"], b["name"]) for b in bases]
|
||||
mock_api.bases.return_value = base_infos
|
||||
|
||||
def base_side_effect(base_id: str) -> MagicMock:
|
||||
mock_base = MagicMock()
|
||||
base_data = next((b for b in bases if b["id"] == base_id), None)
|
||||
if not base_data:
|
||||
raise ValueError(f"Unknown base: {base_id}")
|
||||
|
||||
table_objs = [_make_table_obj(t["id"], t["name"]) for t in base_data["tables"]]
|
||||
mock_base.tables.return_value = table_objs
|
||||
return mock_base
|
||||
|
||||
mock_api.base.side_effect = base_side_effect
|
||||
|
||||
def table_side_effect(base_id: str, table_name_or_id: str) -> MagicMock:
|
||||
base_data = next((b for b in bases if b["id"] == base_id), None)
|
||||
if not base_data:
|
||||
raise ValueError(f"Unknown base: {base_id}")
|
||||
|
||||
table_data = next(
|
||||
(
|
||||
t
|
||||
for t in base_data["tables"]
|
||||
if t["id"] == table_name_or_id or t["name"] == table_name_or_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not table_data:
|
||||
raise ValueError(f"Unknown table: {table_name_or_id}")
|
||||
|
||||
mock_table = MagicMock()
|
||||
mock_table.name = table_data["name"]
|
||||
mock_table.all.return_value = [
|
||||
_make_record(r["id"], r["fields"]) for r in table_data["records"]
|
||||
]
|
||||
|
||||
field_schemas = [
|
||||
_make_field_schema(f["id"], f["name"], f["type"])
|
||||
for f in table_data["fields"]
|
||||
]
|
||||
schema = _make_table_schema(
|
||||
table_data["id"],
|
||||
table_data["name"],
|
||||
table_data["primary_field_id"],
|
||||
field_schemas,
|
||||
)
|
||||
mock_table.schema.return_value = schema
|
||||
return mock_table
|
||||
|
||||
mock_api.table.side_effect = table_side_effect
|
||||
return mock_api
|
||||
|
||||
|
||||
SAMPLE_BASES = [
|
||||
{
|
||||
"id": "appBASE1",
|
||||
"name": "Base One",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblTABLE1",
|
||||
"name": "Table A",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
{"id": "fld2", "name": "Notes", "type": "multilineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recA1", "fields": {"Name": "Alice", "Notes": "Note A"}},
|
||||
{"id": "recA2", "fields": {"Name": "Bob", "Notes": "Note B"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tblTABLE2",
|
||||
"name": "Table B",
|
||||
"primary_field_id": "fld3",
|
||||
"fields": [
|
||||
{"id": "fld3", "name": "Title", "type": "singleLineText"},
|
||||
{"id": "fld4", "name": "Status", "type": "singleSelect"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recB1", "fields": {"Title": "Task 1", "Status": "Done"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "appBASE2",
|
||||
"name": "Base Two",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblTABLE3",
|
||||
"name": "Table C",
|
||||
"primary_field_id": "fld5",
|
||||
"fields": [
|
||||
{"id": "fld5", "name": "Item", "type": "singleLineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recC1", "fields": {"Item": "Widget"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _collect_docs(connector: AirtableConnector) -> list[Document]:
|
||||
docs: list[Document] = []
|
||||
for batch in connector.load_from_state():
|
||||
for item in batch:
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
return docs
|
||||
|
||||
|
||||
class TestIndexAll:
|
||||
@patch("time.sleep")
|
||||
def test_index_all_discovers_all_bases_and_tables(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
connector = AirtableConnector()
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
# 2 records from Table A + 1 from Table B + 1 from Table C = 4
|
||||
assert len(docs) == 4
|
||||
doc_ids = {d.id for d in docs}
|
||||
assert doc_ids == {
|
||||
"airtable__recA1",
|
||||
"airtable__recA2",
|
||||
"airtable__recB1",
|
||||
"airtable__recC1",
|
||||
}
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_index_all_semantic_id_includes_base_name(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
connector = AirtableConnector()
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
docs_by_id = {d.id: d for d in docs}
|
||||
|
||||
assert (
|
||||
docs_by_id["airtable__recA1"].semantic_identifier
|
||||
== "Base One > Table A: Alice"
|
||||
)
|
||||
assert (
|
||||
docs_by_id["airtable__recB1"].semantic_identifier
|
||||
== "Base One > Table B: Task 1"
|
||||
)
|
||||
assert (
|
||||
docs_by_id["airtable__recC1"].semantic_identifier
|
||||
== "Base Two > Table C: Widget"
|
||||
)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_index_all_hierarchy_source_path(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
"""Verify doc_metadata hierarchy source_path is [base_name, table_name]."""
|
||||
connector = AirtableConnector()
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
docs_by_id = {d.id: d for d in docs}
|
||||
|
||||
doc_a1 = docs_by_id["airtable__recA1"]
|
||||
assert doc_a1.doc_metadata is not None
|
||||
assert doc_a1.doc_metadata["hierarchy"]["source_path"] == [
|
||||
"Base One",
|
||||
"Table A",
|
||||
]
|
||||
assert doc_a1.doc_metadata["hierarchy"]["base_name"] == "Base One"
|
||||
assert doc_a1.doc_metadata["hierarchy"]["table_name"] == "Table A"
|
||||
|
||||
doc_c1 = docs_by_id["airtable__recC1"]
|
||||
assert doc_c1.doc_metadata is not None
|
||||
assert doc_c1.doc_metadata["hierarchy"]["source_path"] == [
|
||||
"Base Two",
|
||||
"Table C",
|
||||
]
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_index_all_empty_account(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
connector = AirtableConnector()
|
||||
mock_api = MagicMock()
|
||||
mock_api.bases.return_value = []
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
assert len(docs) == 0
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_index_all_skips_failing_table(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
"""If one table fails, other tables should still be indexed."""
|
||||
bases = [
|
||||
{
|
||||
"id": "appBASE1",
|
||||
"name": "Base One",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblGOOD",
|
||||
"name": "Good Table",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recOK", "fields": {"Name": "Works"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tblBAD",
|
||||
"name": "Bad Table",
|
||||
"primary_field_id": "fldX",
|
||||
"fields": [],
|
||||
"records": [],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
mock_api = _setup_mock_api(bases)
|
||||
|
||||
# Make the bad table raise an error when fetching records
|
||||
original_table_side_effect = mock_api.table.side_effect
|
||||
|
||||
def table_with_failure(base_id: str, table_name_or_id: str) -> MagicMock:
|
||||
if table_name_or_id == "tblBAD":
|
||||
mock_table = MagicMock()
|
||||
mock_table.all.side_effect = Exception("API Error")
|
||||
mock_table.schema.side_effect = Exception("API Error")
|
||||
return mock_table
|
||||
return original_table_side_effect(base_id, table_name_or_id)
|
||||
|
||||
mock_api.table.side_effect = table_with_failure
|
||||
connector = AirtableConnector()
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
# Only the good table's records should come through
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "airtable__recOK"
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_index_all_skips_failing_base(
|
||||
self, mock_sleep: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
"""If listing tables for a base fails, other bases should still be indexed."""
|
||||
bases_data = [
|
||||
{
|
||||
"id": "appGOOD",
|
||||
"name": "Good Base",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblOK",
|
||||
"name": "OK Table",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recOK", "fields": {"Name": "Works"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
mock_api = _setup_mock_api(bases_data)
|
||||
|
||||
# Add a bad base that fails on tables()
|
||||
bad_base_info = _make_base_info("appBAD", "Bad Base")
|
||||
mock_api.bases.return_value = [
|
||||
bad_base_info,
|
||||
*mock_api.bases.return_value,
|
||||
]
|
||||
|
||||
original_base_side_effect = mock_api.base.side_effect
|
||||
|
||||
def base_with_failure(base_id: str) -> MagicMock:
|
||||
if base_id == "appBAD":
|
||||
mock_base = MagicMock()
|
||||
mock_base.tables.side_effect = Exception("Permission denied")
|
||||
return mock_base
|
||||
return original_base_side_effect(base_id)
|
||||
|
||||
mock_api.base.side_effect = base_with_failure
|
||||
|
||||
connector = AirtableConnector()
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "airtable__recOK"
|
||||
|
||||
|
||||
class TestSpecificTableMode:
|
||||
def test_specific_table_unchanged(self) -> None:
|
||||
"""Verify the original single-table behavior still works."""
|
||||
bases = [
|
||||
{
|
||||
"id": "appBASE1",
|
||||
"name": "Base One",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblTABLE1",
|
||||
"name": "Table A",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
{"id": "fld2", "name": "Notes", "type": "multilineText"},
|
||||
],
|
||||
"records": [
|
||||
{
|
||||
"id": "recA1",
|
||||
"fields": {"Name": "Alice", "Notes": "Note"},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
mock_api = _setup_mock_api(bases)
|
||||
|
||||
connector = AirtableConnector(
|
||||
base_id="appBASE1",
|
||||
table_name_or_id="tblTABLE1",
|
||||
)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "airtable__recA1"
|
||||
# No base name prefix in specific mode
|
||||
assert docs[0].semantic_identifier == "Table A: Alice"
|
||||
|
||||
def test_specific_table_resolves_base_name_for_hierarchy(self) -> None:
|
||||
"""In specific mode, bases() is called to resolve the base name for hierarchy."""
|
||||
bases = [
|
||||
{
|
||||
"id": "appBASE1",
|
||||
"name": "Base One",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblTABLE1",
|
||||
"name": "Table A",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recA1", "fields": {"Name": "Test"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
mock_api = _setup_mock_api(bases)
|
||||
|
||||
connector = AirtableConnector(
|
||||
base_id="appBASE1",
|
||||
table_name_or_id="tblTABLE1",
|
||||
)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
# bases() is called to resolve the base name for hierarchy source_path
|
||||
mock_api.bases.assert_called_once()
|
||||
# But base().tables() should NOT be called (no discovery)
|
||||
mock_api.base.assert_not_called()
|
||||
# Semantic identifier should NOT include base name in specific mode
|
||||
assert docs[0].semantic_identifier == "Table A: Test"
|
||||
# Hierarchy should include base name for Craft file system
|
||||
assert docs[0].doc_metadata is not None
|
||||
assert docs[0].doc_metadata["hierarchy"]["source_path"] == [
|
||||
"Base One",
|
||||
"Table A",
|
||||
]
|
||||
|
||||
|
||||
class TestValidateConnectorSettings:
|
||||
def test_validate_index_all_success(self) -> None:
|
||||
connector = AirtableConnector()
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
# Should not raise
|
||||
connector.validate_connector_settings()
|
||||
|
||||
def test_validate_index_all_no_bases(self) -> None:
|
||||
connector = AirtableConnector()
|
||||
mock_api = MagicMock()
|
||||
mock_api.bases.return_value = []
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
with pytest.raises(ConnectorValidationError, match="No bases found"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
def test_validate_specific_table_success(self) -> None:
|
||||
connector = AirtableConnector(
|
||||
base_id="appBASE1",
|
||||
table_name_or_id="tblTABLE1",
|
||||
)
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
# Should not raise
|
||||
connector.validate_connector_settings()
|
||||
|
||||
def test_validate_empty_fields_auto_detects_index_all(self) -> None:
|
||||
"""Empty base_id + table_name_or_id auto-detects as index_all mode."""
|
||||
connector = AirtableConnector(
|
||||
base_id="",
|
||||
table_name_or_id="",
|
||||
)
|
||||
assert connector.index_all is True
|
||||
|
||||
# Validation should go through the index_all path
|
||||
mock_api = _setup_mock_api(SAMPLE_BASES)
|
||||
connector._airtable_client = mock_api
|
||||
connector.validate_connector_settings()
|
||||
|
||||
def test_validate_specific_table_api_error(self) -> None:
|
||||
connector = AirtableConnector(
|
||||
base_id="appBAD",
|
||||
table_name_or_id="tblBAD",
|
||||
)
|
||||
mock_api = MagicMock()
|
||||
mock_table = MagicMock()
|
||||
mock_table.schema.side_effect = Exception("Not found")
|
||||
mock_api.table.return_value = mock_table
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
with pytest.raises(ConnectorValidationError, match="Failed to access table"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
class TestParseAirtableUrl:
|
||||
def test_full_url_with_view(self) -> None:
|
||||
base_id, table_id, view_id = parse_airtable_url(
|
||||
"https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV/viwa3yxZvqWnyXftm?blocks=hide"
|
||||
)
|
||||
assert base_id == "appZqBgQFQ6kWyeZK"
|
||||
assert table_id == "tblc9prNLypy7olTV"
|
||||
assert view_id == "viwa3yxZvqWnyXftm"
|
||||
|
||||
def test_url_without_view(self) -> None:
|
||||
base_id, table_id, view_id = parse_airtable_url(
|
||||
"https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV"
|
||||
)
|
||||
assert base_id == "appZqBgQFQ6kWyeZK"
|
||||
assert table_id == "tblc9prNLypy7olTV"
|
||||
assert view_id is None
|
||||
|
||||
def test_url_without_query_params(self) -> None:
|
||||
base_id, table_id, view_id = parse_airtable_url(
|
||||
"https://airtable.com/appABC123/tblDEF456/viwGHI789"
|
||||
)
|
||||
assert base_id == "appABC123"
|
||||
assert table_id == "tblDEF456"
|
||||
assert view_id == "viwGHI789"
|
||||
|
||||
def test_url_with_trailing_whitespace(self) -> None:
|
||||
base_id, table_id, view_id = parse_airtable_url(
|
||||
" https://airtable.com/appABC123/tblDEF456 "
|
||||
)
|
||||
assert base_id == "appABC123"
|
||||
assert table_id == "tblDEF456"
|
||||
|
||||
def test_invalid_url_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="Could not parse"):
|
||||
parse_airtable_url("https://google.com/something")
|
||||
|
||||
def test_missing_table_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="Could not parse"):
|
||||
parse_airtable_url("https://airtable.com/appABC123")
|
||||
|
||||
def test_empty_string_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="Could not parse"):
|
||||
parse_airtable_url("")
|
||||
|
||||
|
||||
class TestAirtableUrlConnector:
|
||||
def test_url_sets_base_and_table_ids(self) -> None:
|
||||
connector = AirtableConnector(
|
||||
airtable_url="https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV/viwa3yxZvqWnyXftm?blocks=hide"
|
||||
)
|
||||
assert connector.base_id == "appZqBgQFQ6kWyeZK"
|
||||
assert connector.table_name_or_id == "tblc9prNLypy7olTV"
|
||||
assert connector.view_id == "viwa3yxZvqWnyXftm"
|
||||
|
||||
def test_url_without_view_leaves_view_none(self) -> None:
|
||||
connector = AirtableConnector(airtable_url="https://airtable.com/appABC/tblDEF")
|
||||
assert connector.base_id == "appABC"
|
||||
assert connector.table_name_or_id == "tblDEF"
|
||||
assert connector.view_id is None
|
||||
|
||||
def test_url_overrides_explicit_base_and_table(self) -> None:
|
||||
connector = AirtableConnector(
|
||||
base_id="appOLD",
|
||||
table_name_or_id="tblOLD",
|
||||
airtable_url="https://airtable.com/appNEW/tblNEW",
|
||||
)
|
||||
assert connector.base_id == "appNEW"
|
||||
assert connector.table_name_or_id == "tblNEW"
|
||||
|
||||
def test_url_indexes_correctly(self) -> None:
|
||||
"""End-to-end: URL-configured connector fetches from the right table."""
|
||||
bases = [
|
||||
{
|
||||
"id": "appFromUrl",
|
||||
"name": "URL Base",
|
||||
"tables": [
|
||||
{
|
||||
"id": "tblFromUrl",
|
||||
"name": "URL Table",
|
||||
"primary_field_id": "fld1",
|
||||
"fields": [
|
||||
{"id": "fld1", "name": "Name", "type": "singleLineText"},
|
||||
],
|
||||
"records": [
|
||||
{"id": "recURL1", "fields": {"Name": "From URL"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
mock_api = _setup_mock_api(bases)
|
||||
|
||||
connector = AirtableConnector(
|
||||
airtable_url="https://airtable.com/appFromUrl/tblFromUrl/viwABC"
|
||||
)
|
||||
connector._airtable_client = mock_api
|
||||
|
||||
docs = _collect_docs(connector)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "airtable__recURL1"
|
||||
assert docs[0].semantic_identifier == "URL Table: From URL"
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Unit tests for SharepointConnector._create_rest_client_context caching."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.sharepoint.connector import _REST_CTX_MAX_AGE_S
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
SITE_A = "https://tenant.sharepoint.com/sites/SiteA"
|
||||
SITE_B = "https://tenant.sharepoint.com/sites/SiteB"
|
||||
|
||||
FAKE_CREDS = {"sp_client_id": "x", "sp_directory_id": "y"}
|
||||
|
||||
|
||||
def _make_connector() -> SharepointConnector:
|
||||
"""Return a SharepointConnector with minimal credentials wired up."""
|
||||
connector = SharepointConnector(sites=[SITE_A])
|
||||
connector.msal_app = MagicMock()
|
||||
connector.sp_tenant_domain = "tenant"
|
||||
connector._credential_json = FAKE_CREDS
|
||||
return connector
|
||||
|
||||
|
||||
def _noop_load_credentials(connector: SharepointConnector) -> MagicMock:
|
||||
"""Patch load_credentials to just swap in a fresh MagicMock for msal_app."""
|
||||
|
||||
def _fake_load(creds: dict) -> None: # noqa: ARG001, ARG002
|
||||
connector.msal_app = MagicMock()
|
||||
|
||||
mock = MagicMock(side_effect=_fake_load)
|
||||
connector.load_credentials = mock # type: ignore[method-assign]
|
||||
return mock
|
||||
|
||||
|
||||
def _fresh_client_context() -> MagicMock:
|
||||
"""Return a MagicMock for ClientContext that produces a distinct object per call."""
|
||||
mock_cls = MagicMock()
|
||||
# Each ClientContext(url).with_access_token(cb) returns a unique sentinel
|
||||
mock_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
|
||||
return mock_cls
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
@patch("onyx.connectors.sharepoint.connector.ClientContext")
|
||||
def test_returns_cached_context_within_max_age(
|
||||
mock_client_ctx_cls: MagicMock,
|
||||
_mock_acquire: MagicMock,
|
||||
) -> None:
|
||||
"""Repeated calls with the same site_url within the TTL return the same object."""
|
||||
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
|
||||
connector = _make_connector()
|
||||
_noop_load_credentials(connector)
|
||||
|
||||
ctx1 = connector._create_rest_client_context(SITE_A)
|
||||
ctx2 = connector._create_rest_client_context(SITE_A)
|
||||
|
||||
assert ctx1 is ctx2
|
||||
assert mock_client_ctx_cls.call_count == 1
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.time")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
@patch("onyx.connectors.sharepoint.connector.ClientContext")
|
||||
def test_rebuilds_context_after_max_age(
|
||||
mock_client_ctx_cls: MagicMock,
|
||||
_mock_acquire: MagicMock,
|
||||
mock_time: MagicMock,
|
||||
) -> None:
|
||||
"""After _REST_CTX_MAX_AGE_S the cached context is replaced."""
|
||||
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
|
||||
connector = _make_connector()
|
||||
_noop_load_credentials(connector)
|
||||
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
ctx1 = connector._create_rest_client_context(SITE_A)
|
||||
|
||||
# Just past the boundary — should rebuild
|
||||
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 1
|
||||
ctx2 = connector._create_rest_client_context(SITE_A)
|
||||
|
||||
assert ctx1 is not ctx2
|
||||
assert mock_client_ctx_cls.call_count == 2
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
@patch("onyx.connectors.sharepoint.connector.ClientContext")
|
||||
def test_rebuilds_context_on_site_change(
|
||||
mock_client_ctx_cls: MagicMock,
|
||||
_mock_acquire: MagicMock,
|
||||
) -> None:
|
||||
"""Switching to a different site_url forces a new context."""
|
||||
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
|
||||
connector = _make_connector()
|
||||
_noop_load_credentials(connector)
|
||||
|
||||
ctx_a = connector._create_rest_client_context(SITE_A)
|
||||
ctx_b = connector._create_rest_client_context(SITE_B)
|
||||
|
||||
assert ctx_a is not ctx_b
|
||||
assert mock_client_ctx_cls.call_count == 2
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.time")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
@patch("onyx.connectors.sharepoint.connector.ClientContext")
|
||||
def test_load_credentials_called_on_rebuild(
|
||||
_mock_client_ctx_cls: MagicMock,
|
||||
_mock_acquire: MagicMock,
|
||||
mock_time: MagicMock,
|
||||
) -> None:
|
||||
"""load_credentials is called every time the context is rebuilt."""
|
||||
_mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
|
||||
connector = _make_connector()
|
||||
mock_load = _noop_load_credentials(connector)
|
||||
|
||||
# First call — rebuild (no cache yet)
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
connector._create_rest_client_context(SITE_A)
|
||||
assert mock_load.call_count == 1
|
||||
|
||||
# Second call — cache hit, no rebuild
|
||||
mock_time.monotonic.return_value = 100.0
|
||||
connector._create_rest_client_context(SITE_A)
|
||||
assert mock_load.call_count == 1
|
||||
|
||||
# Third call — expired, rebuild
|
||||
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 1
|
||||
connector._create_rest_client_context(SITE_A)
|
||||
assert mock_load.call_count == 2
|
||||
|
||||
# Fourth call — site change, rebuild
|
||||
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 2
|
||||
connector._create_rest_client_context(SITE_B)
|
||||
assert mock_load.call_count == 3
|
||||
@@ -1,3 +1,7 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -137,42 +141,44 @@ def default_multi_llm() -> LitellmLLM:
|
||||
def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
# Mock the litellm.completion function
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create a mock response with multiple tool calls using litellm objects
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content=None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
litellm.ChatCompletionMessageToolCall(
|
||||
id="call_1",
|
||||
function=LiteLLMFunction(
|
||||
name="get_weather",
|
||||
arguments='{"location": "New York"}',
|
||||
# invoke() internally uses stream=True and reassembles via
|
||||
# stream_chunk_builder, so the mock must return stream chunks.
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
id="call_1",
|
||||
function=LiteLLMFunction(
|
||||
name="get_weather",
|
||||
arguments='{"location": "New York"}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
litellm.ChatCompletionMessageToolCall(
|
||||
id="call_2",
|
||||
function=LiteLLMFunction(
|
||||
name="get_time", arguments='{"timezone": "EST"}'
|
||||
ChatCompletionDeltaToolCall(
|
||||
id="call_2",
|
||||
function=LiteLLMFunction(
|
||||
name="get_time",
|
||||
arguments='{"timezone": "EST"}',
|
||||
),
|
||||
type="function",
|
||||
index=1,
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
usage=litellm.Usage(
|
||||
prompt_tokens=50, completion_tokens=30, total_tokens=80
|
||||
],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
# Define input messages
|
||||
messages: LanguageModelInput = [
|
||||
@@ -246,11 +252,12 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=False,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
max_tokens=None,
|
||||
client=ANY, # HTTPHandler instance created per-request
|
||||
stream_options={"include_usage": True},
|
||||
parallel_tool_calls=True,
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
allowed_openai_params=["tool_choice"],
|
||||
@@ -507,21 +514,20 @@ def test_openai_chat_omits_reasoning_params() -> None:
|
||||
"onyx.llm.multi_llm.is_true_openai_model", return_value=True
|
||||
) as mock_is_openai,
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-5-chat",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -539,21 +545,20 @@ def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -573,21 +578,20 @@ def test_user_identity_user_id_truncated_to_64_chars(
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
long_user_id = "u" * 82
|
||||
@@ -607,21 +611,20 @@ def test_user_identity_metadata_disabled_omits_identity(
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -654,21 +657,20 @@ def test_existing_metadata_pass_through_when_identity_disabled() -> None:
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -688,18 +690,20 @@ def test_openai_model_invoke_uses_httphandler_client(
|
||||
from litellm import HTTPHandler
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
default_multi_llm.invoke(messages)
|
||||
@@ -737,18 +741,20 @@ def test_anthropic_model_passes_no_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -769,18 +775,20 @@ def test_bedrock_model_passes_no_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -809,18 +817,20 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="gpt-4o",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-4o",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -828,3 +838,372 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
|
||||
mock_completion.assert_called_once()
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert isinstance(kwargs["client"], HTTPHandler)
|
||||
|
||||
|
||||
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Assign some environment variables
|
||||
EXPECTED_ENV_VARS = {
|
||||
"TEST_ENV_VAR": "test_value",
|
||||
"ANOTHER_ONE": "1",
|
||||
"THIRD_ONE": "2",
|
||||
}
|
||||
|
||||
CUSTOM_CONFIG = {
|
||||
"TEST_ENV_VAR": "fdsfsdf",
|
||||
"ANOTHER_ONE": "3",
|
||||
"THIS_IS_RANDOM": "123213",
|
||||
}
|
||||
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
monkeypatch.setenv(env_var, value)
|
||||
|
||||
model_provider = LlmProviderNames.OPENAI
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
),
|
||||
model_kwargs={"metadata": {"foo": "bar"}},
|
||||
custom_config=CUSTOM_CONFIG,
|
||||
)
|
||||
|
||||
# When custom_config is set, invoke() internally uses stream=True and
|
||||
# reassembles via stream_chunk_builder, so the mock must return stream chunks.
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
|
||||
def on_litellm_completion(
|
||||
**kwargs: dict[str, Any], # noqa: ARG001
|
||||
) -> list[litellm.ModelResponse]:
|
||||
# Validate that the environment variables are those in custom config
|
||||
for env_var, value in CUSTOM_CONFIG.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
return mock_stream_chunks
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_completion.side_effect = on_litellm_completion
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
|
||||
llm.invoke(messages, user_identity=identity)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert kwargs["stream"] is True
|
||||
assert "user" not in kwargs
|
||||
assert kwargs["metadata"]["foo"] == "bar"
|
||||
|
||||
# Check that the environment variables are back to the original values
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
# Check that temporary env var from CUSTOM_CONFIG is no longer set
|
||||
assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """An exception raised mid-call must not leave custom_config env vars behind.

    The LLM's custom_config overrides two pre-existing env vars and introduces a
    brand-new one. The mocked litellm.completion verifies the overrides are live,
    then raises; afterwards the pre-existing values must be restored and the
    new key removed.
    """
    # Baseline environment the test starts from (must be restored afterwards).
    original_env = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }

    # Overrides applied only for the duration of the LLM call.
    custom_config = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }

    for name, val in original_env.items():
        monkeypatch.setenv(name, val)

    provider = LlmProviderNames.OPENAI
    model = "gpt-3.5-turbo"

    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=provider,
        model_name=model,
        max_input_tokens=get_max_input_tokens(
            model_provider=provider,
            model_name=model,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=custom_config,
    )

    def failing_completion(**kwargs: dict[str, Any]) -> None:  # noqa: ARG001
        # While the call is in flight, every custom_config value must be live.
        for name, val in custom_config.items():
            assert name in os.environ
            assert os.environ[name] == val

        # Simulate an error during the LLM call.
        raise RuntimeError("Simulated LLM API failure")

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = failing_completion

        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")

        with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
            llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()

    # The original values are back...
    for name, val in original_env.items():
        assert name in os.environ
        assert os.environ[name] == val

    # ...and the key that existed only in custom_config is gone again.
    assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_stream", [False, True], ids=["invoke", "stream"])
def test_multithreaded_custom_config_isolation(
    monkeypatch: pytest.MonkeyPatch,
    use_stream: bool,
) -> None:
    """Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.

    Two LitellmLLM instances with different custom_config dicts call invoke/stream
    concurrently. The _env_lock in temporary_env_and_lock serializes their access so
    each call only ever sees its own env vars—never the other's.
    """
    # Ensure these keys start unset
    monkeypatch.delenv("SHARED_KEY", raising=False)
    monkeypatch.delenv("LLM_A_ONLY", raising=False)
    monkeypatch.delenv("LLM_B_ONLY", raising=False)

    CONFIG_A = {
        "SHARED_KEY": "value_from_A",
        "LLM_A_ONLY": "a_secret",
    }
    CONFIG_B = {
        "SHARED_KEY": "value_from_B",
        "LLM_B_ONLY": "b_secret",
    }

    # Deterministic ordering (sorted) keeps snapshots easy to compare on failure.
    all_env_keys = sorted(set(CONFIG_A) | set(CONFIG_B))

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_A,
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_B,
    )

    # Both invoke (with custom_config) and stream use stream=True at the
    # litellm level, so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # Track what each call observed inside litellm.completion.
    # Keyed by api_key so we can identify which LLM instance made the call.
    observed_envs: dict[str, dict[str, str | None]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        time.sleep(0.1)  # We expect someone to get caught on the lock
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"

        snapshot: dict[str, str | None] = {}
        for key in all_env_keys:
            snapshot[key] = os.environ.get(key)
        observed_envs[label] = snapshot

        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            if use_stream:
                list(llm.stream(messages))
            else:
                llm.invoke(messages)
        except Exception as e:
            errors.append(e)

    with patch("litellm.completion", side_effect=fake_completion):
        t_a = threading.Thread(target=run_llm, args=(llm_a,))
        t_b = threading.Thread(target=run_llm, args=(llm_b,))

        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    # join(timeout=...) returns silently on timeout; fail loudly if a thread is
    # still running (e.g. the env lock was never released → deadlock).
    assert not t_a.is_alive(), "thread A did not finish within 10s"
    assert not t_b.is_alive(), "thread B did not finish within 10s"

    assert not errors, f"Thread errors: {errors}"
    assert "A" in observed_envs and "B" in observed_envs

    # Thread A must have seen its own config for SHARED_KEY, not B's
    assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
    assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
    # A must NOT see B's exclusive key
    assert observed_envs["A"]["LLM_B_ONLY"] is None

    # Thread B must have seen its own config for SHARED_KEY, not A's
    assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
    assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
    # B must NOT see A's exclusive key
    assert observed_envs["B"]["LLM_A_ONLY"] is None

    # After both calls, env should be clean
    assert os.environ.get("SHARED_KEY") is None
    assert os.environ.get("LLM_A_ONLY") is None
    assert os.environ.get("LLM_B_ONLY") is None
|
||||
|
||||
|
||||
def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
    """Verify that invoke() without custom_config does not acquire the env lock.

    Two LitellmLLM instances without custom_config call invoke concurrently.
    Both calls reach litellm with stream=True (invoke always streams internally
    and reassembles the chunks), never touch the env lock, and complete without
    blocking each other.
    """
    from onyx.llm import multi_llm as multi_llm_module

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )

    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # Kwargs seen by litellm.completion, keyed by which instance made the call.
    call_kwargs: dict[str, dict[str, Any]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"
        call_kwargs[label] = kwargs
        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            llm.invoke(messages)
        except Exception as e:
            errors.append(e)

    with (
        patch("litellm.completion", side_effect=fake_completion),
        patch.object(
            multi_llm_module,
            "temporary_env_and_lock",
            wraps=multi_llm_module.temporary_env_and_lock,
        ) as mock_env_lock,
    ):
        t_a = threading.Thread(target=run_llm, args=(llm_a,))
        t_b = threading.Thread(target=run_llm, args=(llm_b,))

        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    # join(timeout=...) returns silently on timeout; fail loudly on a hang.
    assert not t_a.is_alive(), "thread A did not finish within 10s"
    assert not t_b.is_alive(), "thread B did not finish within 10s"

    assert not errors, f"Thread errors: {errors}"
    assert "A" in call_kwargs and "B" in call_kwargs

    # invoke() always uses stream=True internally (reassembles via stream_chunk_builder)
    assert call_kwargs["A"]["stream"] is True
    assert call_kwargs["B"]["stream"] is True

    # The env lock context manager should never have been called
    mock_env_lock.assert_not_called()
|
||||
|
||||
15
backend/tests/unit/onyx/prompts/test_prompt_utils.py
Normal file
15
backend/tests/unit/onyx/prompts/test_prompt_utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
|
||||
from onyx.prompts.prompt_utils import replace_reminder_tag
|
||||
|
||||
|
||||
def test_replace_reminder_tag_pattern() -> None:
    """The placeholder is swapped for the full reminder-tag description text."""
    original = "Some text {{REMINDER_TAG_DESCRIPTION}} more text"
    replaced = replace_reminder_tag(original)
    assert "{{REMINDER_TAG_DESCRIPTION}}" not in replaced
    assert REMINDER_TAG_DESCRIPTION in replaced
|
||||
|
||||
|
||||
def test_replace_reminder_tag_no_pattern() -> None:
    """A prompt without the placeholder passes through untouched."""
    original = "Some text without any pattern"
    assert replace_reminder_tag(original) == original
|
||||
0
backend/tests/unit/onyx/server/__init__.py
Normal file
0
backend/tests/unit/onyx/server/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Unit tests for Zed-style ACP session management in KubernetesSandboxManager.
|
||||
|
||||
These tests verify that the KubernetesSandboxManager correctly:
|
||||
- Maintains one shared ACPExecClient per sandbox
|
||||
- Maps craft sessions to ACP sessions on the shared client
|
||||
- Replaces dead clients and re-creates sessions
|
||||
- Cleans up on terminate/cleanup
|
||||
|
||||
All external dependencies (K8s, WebSockets, packet logging) are mocked.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The fully-qualified path to the module under test, used for patching
|
||||
_K8S_MODULE = "onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager"
|
||||
_ACP_CLIENT_CLASS = f"{_K8S_MODULE}.ACPExecClient"
|
||||
_GET_PACKET_LOGGER = f"{_K8S_MODULE}.get_packet_logger"
|
||||
|
||||
|
||||
def _make_mock_event() -> MagicMock:
|
||||
"""Create a mock ACP event."""
|
||||
return MagicMock(name="mock_acp_event")
|
||||
|
||||
|
||||
def _make_mock_client(
|
||||
is_running: bool = True,
|
||||
session_ids: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock ACPExecClient with configurable state.
|
||||
|
||||
Args:
|
||||
is_running: Whether the client appears running
|
||||
session_ids: List of ACP session IDs the client tracks
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
type(mock_client).is_running = property(lambda _self: is_running)
|
||||
type(mock_client).session_ids = property(
|
||||
lambda _self: session_ids if session_ids is not None else []
|
||||
)
|
||||
mock_client.start.return_value = None
|
||||
mock_client.stop.return_value = None
|
||||
|
||||
# get_or_create_session returns a unique ACP session ID
|
||||
mock_client.get_or_create_session.return_value = f"acp-session-{uuid4().hex[:8]}"
|
||||
|
||||
mock_event = _make_mock_event()
|
||||
mock_client.send_message.return_value = iter([mock_event])
|
||||
return mock_client
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[Any, None, None]) -> list[Any]:
|
||||
"""Consume a generator and return all yielded values as a list."""
|
||||
return list(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: fresh KubernetesSandboxManager instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
def manager() -> Generator[Any, None, None]:
    """Yield a pristine KubernetesSandboxManager with all externals mocked.

    K8s config/client/stream and the packet logger are patched so no real
    cluster access happens; the singleton is reset before and after the test.
    """
    with (
        patch(f"{_K8S_MODULE}.config") as _mock_config,
        patch(f"{_K8S_MODULE}.client") as _mock_k8s_client,
        patch(f"{_K8S_MODULE}.k8s_stream"),
        patch(_GET_PACKET_LOGGER) as mock_get_logger,
    ):
        mock_get_logger.return_value = MagicMock()

        # In-cluster config loading becomes a no-op.
        _mock_config.load_incluster_config.return_value = None
        _mock_config.ConfigException = Exception

        # Every K8s API surface hands back an inert mock.
        for api_name in ("ApiClient", "CoreV1Api", "BatchV1Api", "NetworkingV1Api"):
            getattr(_mock_k8s_client, api_name).return_value = MagicMock()

        from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
            KubernetesSandboxManager,
        )

        # Reset the singleton so each test constructs a fresh instance.
        KubernetesSandboxManager._instance = None
        instance = KubernetesSandboxManager()

        yield instance

        KubernetesSandboxManager._instance = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Shared client lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_send_message_creates_shared_client_on_first_call(manager: Any) -> None:
    """The first send_message() for a sandbox builds one shared ACPExecClient,
    opens an ACP session for the craft session, and yields its events."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    message = "hello world"

    expected_event = _make_mock_event()
    client = _make_mock_client(is_running=True)
    acp_session_id = "acp-session-abc"
    client.get_or_create_session.return_value = acp_session_id
    # session_ids must include the created session for validation
    type(client).session_ids = property(lambda _: [acp_session_id])
    client.send_message.return_value = iter([expected_event])

    with patch(_ACP_CLIENT_CLASS, return_value=client) as client_cls:
        yielded = _drain_generator(manager.send_message(sandbox_id, session_id, message))

    # Exactly one shared client constructed for the sandbox.
    client_cls.assert_called_once()

    # start() targets the workspace root, not a session-specific path.
    client.start.assert_called_once_with(cwd="/workspace")

    # The ACP session is created under the session's own directory.
    client.get_or_create_session.assert_called_once_with(
        cwd=f"/workspace/sessions/{session_id}"
    )

    # The message is routed to the newly created ACP session.
    client.send_message.assert_called_once_with(message, session_id=acp_session_id)

    # The event made it out of the generator.
    assert expected_event in yielded

    # The shared client is cached per sandbox...
    assert sandbox_id in manager._acp_clients
    assert manager._acp_clients[sandbox_id] is client

    # ...and the craft-session → ACP-session mapping is recorded.
    assert (sandbox_id, session_id) in manager._acp_session_ids
    assert manager._acp_session_ids[(sandbox_id, session_id)] == acp_session_id
|
||||
|
||||
|
||||
def test_send_message_reuses_shared_client_for_same_session(manager: Any) -> None:
    """Second call with the same session should reuse the shared client
    and the same ACP session ID."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()

    mock_event_1 = _make_mock_event()
    mock_event_2 = _make_mock_event()
    mock_client = _make_mock_client(is_running=True)
    acp_session_id = "acp-session-reuse"
    mock_client.get_or_create_session.return_value = acp_session_id
    type(mock_client).session_ids = property(lambda _: [acp_session_id])

    mock_client.send_message.side_effect = [
        iter([mock_event_1]),
        iter([mock_event_2]),
    ]

    with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
        events_1 = _drain_generator(
            manager.send_message(sandbox_id, session_id, "first")
        )
        events_2 = _drain_generator(
            manager.send_message(sandbox_id, session_id, "second")
        )

    # Constructor called only ONCE (shared client)
    MockClass.assert_called_once()

    # start() called only once
    mock_client.start.assert_called_once()

    # get_or_create_session called only once (second call uses cached mapping)
    mock_client.get_or_create_session.assert_called_once()

    # send_message called twice with same ACP session ID
    assert mock_client.send_message.call_count == 2
    # FIX: actually verify "same ACP session ID" — the old test only counted
    # calls and would have passed even if a new session was used each time.
    for call in mock_client.send_message.call_args_list:
        assert call.kwargs["session_id"] == acp_session_id

    assert mock_event_1 in events_1
    assert mock_event_2 in events_2
|
||||
|
||||
|
||||
def test_send_message_different_sessions_share_client(manager: Any) -> None:
    """Two craft sessions on one sandbox share a single ACPExecClient while
    each getting its own ACP session."""
    sandbox_id: UUID = uuid4()
    session_id_a: UUID = uuid4()
    session_id_b: UUID = uuid4()

    client = _make_mock_client(is_running=True)
    acp_session_a = "acp-session-a"
    acp_session_b = "acp-session-b"
    client.get_or_create_session.side_effect = [acp_session_a, acp_session_b]
    type(client).session_ids = property(lambda _: [acp_session_a, acp_session_b])

    event_a = _make_mock_event()
    event_b = _make_mock_event()
    client.send_message.side_effect = [
        iter([event_a]),
        iter([event_b]),
    ]

    with patch(_ACP_CLIENT_CLASS, return_value=client) as client_cls:
        events_a = _drain_generator(
            manager.send_message(sandbox_id, session_id_a, "msg a")
        )
        events_b = _drain_generator(
            manager.send_message(sandbox_id, session_id_b, "msg b")
        )

    # A single shared client serves both craft sessions.
    client_cls.assert_called_once()

    # One ACP session created per craft session.
    assert client.get_or_create_session.call_count == 2

    # Each message went to its own ACP session.
    client.send_message.assert_any_call("msg a", session_id=acp_session_a)
    client.send_message.assert_any_call("msg b", session_id=acp_session_b)

    # Both craft-session mappings are recorded.
    assert manager._acp_session_ids[(sandbox_id, session_id_a)] == acp_session_a
    assert manager._acp_session_ids[(sandbox_id, session_id_b)] == acp_session_b

    # Each caller received its own event.
    assert event_a in events_a
    assert event_b in events_b
|
||||
|
||||
|
||||
def test_send_message_replaces_dead_client(manager: Any) -> None:
    """A cached client with is_running == False is stopped, replaced, and its
    craft sessions re-created on the replacement."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()

    # Seed the cache with a dead client and a stale session mapping.
    dead_client = _make_mock_client(is_running=False)
    manager._acp_clients[sandbox_id] = dead_client
    manager._acp_session_ids[(sandbox_id, session_id)] = "old-acp-session"

    # The replacement client the patched constructor will hand back.
    new_event = _make_mock_event()
    new_client = _make_mock_client(is_running=True)
    new_acp_session = "new-acp-session"
    new_client.get_or_create_session.return_value = new_acp_session
    type(new_client).session_ids = property(lambda _: [new_acp_session])
    new_client.send_message.return_value = iter([new_event])

    with patch(_ACP_CLIENT_CLASS, return_value=new_client):
        events = _drain_generator(manager.send_message(sandbox_id, session_id, "test"))

    # The dead client was stopped during replacement...
    dead_client.stop.assert_called_once()
    # ...and the new one started.
    new_client.start.assert_called_once()

    # The stale mapping was replaced with the fresh ACP session.
    assert manager._acp_session_ids[(sandbox_id, session_id)] == new_acp_session

    # The cache now holds the replacement client.
    assert manager._acp_clients[sandbox_id] is new_client

    assert new_event in events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_terminate_stops_shared_client(manager: Any) -> None:
    """terminate(sandbox_id) stops the shared client and drops every
    session mapping belonging to that sandbox."""
    sandbox_id: UUID = uuid4()
    sessions = [uuid4(), uuid4()]

    client = _make_mock_client(is_running=True)
    manager._acp_clients[sandbox_id] = client
    manager._acp_session_ids[(sandbox_id, sessions[0])] = "acp-1"
    manager._acp_session_ids[(sandbox_id, sessions[1])] = "acp-2"

    with patch.object(manager, "_cleanup_kubernetes_resources"):
        manager.terminate(sandbox_id)

    # Shared client stopped and evicted from the cache.
    client.stop.assert_called_once()
    assert sandbox_id not in manager._acp_clients

    # Both session mappings for this sandbox are gone.
    for session_id in sessions:
        assert (sandbox_id, session_id) not in manager._acp_session_ids
|
||||
|
||||
|
||||
def test_terminate_leaves_other_sandbox_untouched(manager: Any) -> None:
    """Terminating one sandbox must not disturb another sandbox's client
    or session mappings."""
    sandbox_a: UUID = uuid4()
    sandbox_b: UUID = uuid4()
    session_a: UUID = uuid4()
    session_b: UUID = uuid4()

    client_a = _make_mock_client(is_running=True)
    client_b = _make_mock_client(is_running=True)

    manager._acp_clients[sandbox_a] = client_a
    manager._acp_clients[sandbox_b] = client_b
    manager._acp_session_ids[(sandbox_a, session_a)] = "acp-a"
    manager._acp_session_ids[(sandbox_b, session_b)] = "acp-b"

    with patch.object(manager, "_cleanup_kubernetes_resources"):
        manager.terminate(sandbox_a)

    # Sandbox A fully cleaned up.
    client_a.stop.assert_called_once()
    assert sandbox_a not in manager._acp_clients
    assert (sandbox_a, session_a) not in manager._acp_session_ids

    # Sandbox B untouched: client still cached and running, mapping intact.
    client_b.stop.assert_not_called()
    assert sandbox_b in manager._acp_clients
    assert manager._acp_session_ids[(sandbox_b, session_b)] == "acp-b"
|
||||
|
||||
|
||||
def test_cleanup_session_removes_session_mapping(manager: Any) -> None:
    """cleanup_session_workspace() drops the session's mapping while leaving
    the shared client running for any other sessions."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()

    client = _make_mock_client(is_running=True)
    manager._acp_clients[sandbox_id] = client
    manager._acp_session_ids[(sandbox_id, session_id)] = "acp-session-xyz"

    with patch.object(manager, "_stream_core_api") as stream_api:
        stream_api.connect_get_namespaced_pod_exec = MagicMock()
        with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
            manager.cleanup_session_workspace(sandbox_id, session_id)

    # The session mapping is gone...
    assert (sandbox_id, session_id) not in manager._acp_session_ids

    # ...but the shared client survives for other sessions.
    client.stop.assert_not_called()
    assert sandbox_id in manager._acp_clients
|
||||
|
||||
|
||||
def test_cleanup_session_handles_no_mapping(manager: Any) -> None:
    """cleanup_session_workspace() is a safe no-op when the session was
    never mapped to an ACP session."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()

    # Precondition: no mapping exists.
    assert (sandbox_id, session_id) not in manager._acp_session_ids

    with patch.object(manager, "_stream_core_api") as stream_api:
        stream_api.connect_get_namespaced_pod_exec = MagicMock()
        with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
            manager.cleanup_session_workspace(sandbox_id, session_id)

    # Still no mapping — and no exception was raised.
    assert (sandbox_id, session_id) not in manager._acp_session_ids
|
||||
0
backend/tests/unit/onyx/server/scim/__init__.py
Normal file
0
backend/tests/unit/onyx/server/scim/__init__.py
Normal file
93
backend/tests/unit/onyx/server/scim/test_filtering.py
Normal file
93
backend/tests/unit/onyx/server/scim/test_filtering.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
|
||||
|
||||
class TestParseScimFilter:
    """Tests for SCIM filter expression parsing."""

    @staticmethod
    def _expected(attribute: str, operator: ScimFilterOperator, value: str) -> ScimFilter:
        """Shorthand for building the expected parse result."""
        return ScimFilter(attribute=attribute, operator=operator, value=value)

    def test_eq_filter_double_quoted(self) -> None:
        assert parse_scim_filter('userName eq "john@example.com"') == self._expected(
            "userName", ScimFilterOperator.EQUAL, "john@example.com"
        )

    def test_eq_filter_single_quoted(self) -> None:
        assert parse_scim_filter("userName eq 'john@example.com'") == self._expected(
            "userName", ScimFilterOperator.EQUAL, "john@example.com"
        )

    def test_co_filter(self) -> None:
        assert parse_scim_filter('displayName co "Engineering"') == self._expected(
            "displayName", ScimFilterOperator.CONTAINS, "Engineering"
        )

    def test_sw_filter(self) -> None:
        assert parse_scim_filter('userName sw "admin"') == self._expected(
            "userName", ScimFilterOperator.STARTS_WITH, "admin"
        )

    def test_case_insensitive_operator(self) -> None:
        # Operator keywords match regardless of case.
        parsed = parse_scim_filter('userName EQ "test@example.com"')
        assert parsed is not None
        assert parsed.operator == ScimFilterOperator.EQUAL

    def test_external_id_filter(self) -> None:
        assert parse_scim_filter('externalId eq "abc-123"') == self._expected(
            "externalId", ScimFilterOperator.EQUAL, "abc-123"
        )

    def test_empty_value(self) -> None:
        assert parse_scim_filter('userName eq ""') == self._expected(
            "userName", ScimFilterOperator.EQUAL, ""
        )

    def test_whitespace_trimming(self) -> None:
        # Surrounding whitespace in the filter string is ignored.
        parsed = parse_scim_filter(' userName eq "test" ')
        assert parsed is not None
        assert parsed.value == "test"

    @pytest.mark.parametrize(
        "filter_string",
        [
            None,
            "",
            " ",
        ],
    )
    def test_empty_input_returns_none(self, filter_string: str | None) -> None:
        assert parse_scim_filter(filter_string) is None

    @pytest.mark.parametrize(
        "filter_string",
        [
            "userName",  # missing operator and value
            "userName eq",  # missing value
            'userName gt "5"',  # unsupported operator
            'userName ne "test"',  # unsupported operator
            "userName eq unquoted",  # unquoted value
            'a eq "x" and b eq "y"',  # compound filter not supported
        ],
    )
    def test_malformed_input_raises_value_error(self, filter_string: str) -> None:
        with pytest.raises(ValueError, match="Unsupported or malformed"):
            parse_scim_filter(filter_string)
|
||||
258
backend/tests/unit/onyx/server/scim/test_patch.py
Normal file
258
backend/tests/unit/onyx/server/scim/test_patch.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
    """Build a ScimUserResource with sensible defaults; kwargs override them."""
    fields: dict = {
        "userName": "test@example.com",
        "active": True,
        "name": ScimName(givenName="Test", familyName="User"),
        **kwargs,
    }
    return ScimUserResource(**fields)
|
||||
|
||||
|
||||
def _make_group(**kwargs: object) -> ScimGroupResource:
    """Build a ScimGroupResource with a default displayName; kwargs override."""
    fields: dict = {"displayName": "Engineering", **kwargs}
    return ScimGroupResource(**fields)
|
||||
|
||||
|
||||
def _replace_op(
    path: str | None = None,
    value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
    """Build a SCIM 'replace' patch operation for the given path/value."""
    return ScimPatchOperation(
        op=ScimPatchOperationType.REPLACE,
        path=path,
        value=value,
    )
|
||||
|
||||
|
||||
def _add_op(
    path: str | None = None,
    value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
    """Build a SCIM 'add' patch operation for the given path/value."""
    return ScimPatchOperation(
        op=ScimPatchOperationType.ADD,
        path=path,
        value=value,
    )
|
||||
|
||||
|
||||
def _remove_op(path: str) -> ScimPatchOperation:
    """Build a SCIM 'remove' patch operation targeting ``path``."""
    return ScimPatchOperation(op=ScimPatchOperationType.REMOVE, path=path)
|
||||
|
||||
|
||||
class TestApplyUserPatch:
    """Tests for SCIM user PATCH operations."""

    def test_deactivate_user(self) -> None:
        original = _make_user()
        patched = apply_user_patch([_replace_op("active", False)], original)
        assert patched.active is False
        # Untouched fields survive the patch.
        assert patched.userName == "test@example.com"

    def test_activate_user(self) -> None:
        original = _make_user(active=False)
        patched = apply_user_patch([_replace_op("active", True)], original)
        assert patched.active is True

    def test_replace_given_name(self) -> None:
        patched = apply_user_patch(
            [_replace_op("name.givenName", "NewFirst")], _make_user()
        )
        assert patched.name is not None
        assert patched.name.givenName == "NewFirst"
        # Sibling sub-attribute is preserved.
        assert patched.name.familyName == "User"

    def test_replace_family_name(self) -> None:
        patched = apply_user_patch(
            [_replace_op("name.familyName", "NewLast")], _make_user()
        )
        assert patched.name is not None
        assert patched.name.familyName == "NewLast"

    def test_replace_username(self) -> None:
        patched = apply_user_patch(
            [_replace_op("userName", "new@example.com")], _make_user()
        )
        assert patched.userName == "new@example.com"

    def test_replace_without_path_uses_dict(self) -> None:
        # A path-less replace takes a dict of attribute → new value.
        patched = apply_user_patch(
            [_replace_op(None, {"active": False, "userName": "new@example.com"})],
            _make_user(),
        )
        assert patched.active is False
        assert patched.userName == "new@example.com"

    def test_multiple_operations(self) -> None:
        ops = [
            _replace_op("active", False),
            _replace_op("name.givenName", "Updated"),
        ]
        patched = apply_user_patch(ops, _make_user())
        assert patched.active is False
        assert patched.name is not None
        assert patched.name.givenName == "Updated"

    def test_case_insensitive_path(self) -> None:
        # Attribute paths match regardless of case.
        patched = apply_user_patch([_replace_op("Active", False)], _make_user())
        assert patched.active is False

    def test_original_not_mutated(self) -> None:
        original = _make_user()
        apply_user_patch([_replace_op("active", False)], original)
        assert original.active is True

    def test_unsupported_path_raises(self) -> None:
        with pytest.raises(ScimPatchError, match="Unsupported path"):
            apply_user_patch([_replace_op("unknownField", "value")], _make_user())

    def test_remove_op_on_user_raises(self) -> None:
        with pytest.raises(ScimPatchError, match="Unsupported operation"):
            apply_user_patch([_remove_op("active")], _make_user())
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
|
||||
def test_replace_display_name(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, removed = apply_group_patch(
|
||||
[_replace_op("displayName", "New Name")], group
|
||||
)
|
||||
assert result.displayName == "New Name"
|
||||
assert added == []
|
||||
assert removed == []
|
||||
|
||||
def test_add_members(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, removed = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
assert added == ["user-1", "user-2"]
|
||||
assert removed == []
|
||||
|
||||
def test_add_members_without_path(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op(None, [{"value": "user-1"}])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 1
|
||||
assert added == ["user-1"]
|
||||
|
||||
def test_add_duplicate_member_skipped(self) -> None:
|
||||
group = _make_group(members=[ScimGroupMember(value="user-1")])
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
assert added == ["user-2"]
|
||||
|
||||
def test_remove_member(self) -> None:
|
||||
group = _make_group(
|
||||
members=[
|
||||
ScimGroupMember(value="user-1"),
|
||||
ScimGroupMember(value="user-2"),
|
||||
]
|
||||
)
|
||||
result, added, removed = apply_group_patch(
|
||||
[_remove_op('members[value eq "user-1"]')],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 1
|
||||
assert result.members[0].value == "user-2"
|
||||
assert removed == ["user-1"]
|
||||
assert added == []
|
||||
|
||||
def test_remove_nonexistent_member(self) -> None:
|
||||
group = _make_group(members=[ScimGroupMember(value="user-1")])
|
||||
result, _, removed = apply_group_patch(
|
||||
[_remove_op('members[value eq "user-999"]')],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 1
|
||||
assert removed == []
|
||||
|
||||
def test_mixed_operations(self) -> None:
|
||||
group = _make_group(members=[ScimGroupMember(value="user-1")])
|
||||
result, added, removed = apply_group_patch(
|
||||
[
|
||||
_replace_op("displayName", "Renamed"),
|
||||
_add_op("members", [{"value": "user-2"}]),
|
||||
_remove_op('members[value eq "user-1"]'),
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert result.displayName == "Renamed"
|
||||
assert added == ["user-2"]
|
||||
assert removed == ["user-1"]
|
||||
assert len(result.members) == 1
|
||||
|
||||
def test_remove_without_path_raises(self) -> None:
|
||||
group = _make_group()
|
||||
with pytest.raises(ScimPatchError, match="requires a path"):
|
||||
apply_group_patch(
|
||||
[ScimPatchOperation(op=ScimPatchOperationType.REMOVE, path=None)],
|
||||
group,
|
||||
)
|
||||
|
||||
def test_remove_invalid_path_raises(self) -> None:
|
||||
group = _make_group()
|
||||
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
|
||||
apply_group_patch([_remove_op("displayName")], group)
|
||||
|
||||
def test_replace_members_with_path(self) -> None:
|
||||
group = _make_group(
|
||||
members=[
|
||||
ScimGroupMember(value="user-1"),
|
||||
ScimGroupMember(value="user-2"),
|
||||
]
|
||||
)
|
||||
result, added, removed = apply_group_patch(
|
||||
[_replace_op("members", [{"value": "user-2"}, {"value": "user-3"}])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
member_ids = {m.value for m in result.members}
|
||||
assert member_ids == {"user-2", "user-3"}
|
||||
assert "user-3" in added
|
||||
assert "user-1" in removed
|
||||
assert "user-2" not in added
|
||||
assert "user-2" not in removed
|
||||
|
||||
def test_replace_members_empty_list_clears(self) -> None:
|
||||
group = _make_group(
|
||||
members=[
|
||||
ScimGroupMember(value="user-1"),
|
||||
ScimGroupMember(value="user-2"),
|
||||
]
|
||||
)
|
||||
result, added, removed = apply_group_patch(
|
||||
[_replace_op("members", [])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 0
|
||||
assert added == []
|
||||
assert set(removed) == {"user-1", "user-2"}
|
||||
|
||||
def test_unsupported_replace_path_raises(self) -> None:
|
||||
group = _make_group()
|
||||
with pytest.raises(ScimPatchError, match="Unsupported path"):
|
||||
apply_group_patch([_replace_op("unknownField", "val")], group)
|
||||
|
||||
def test_original_not_mutated(self) -> None:
|
||||
group = _make_group()
|
||||
apply_group_patch([_replace_op("displayName", "Changed")], group)
|
||||
assert group.displayName == "Engineering"
|
||||
@@ -367,7 +367,6 @@ webserver:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -457,7 +456,6 @@ api:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -553,7 +551,6 @@ celery_worker_heavy:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -587,7 +584,6 @@ celery_worker_docprocessing:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -621,7 +617,6 @@ celery_worker_light:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -655,7 +650,6 @@ celery_worker_monitoring:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -689,7 +683,6 @@ celery_worker_primary:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -723,7 +716,6 @@ celery_worker_user_file_processing:
|
||||
# KEDA specific configurations
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
@@ -868,7 +860,6 @@ celery_worker_docfetching:
|
||||
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
|
||||
pollingInterval: 30 # seconds
|
||||
cooldownPeriod: 300 # seconds
|
||||
idleReplicaCount: 1 # minimum replicas when idle
|
||||
failureThreshold: 3 # number of failures before fallback
|
||||
fallbackReplicas: 1 # replicas to maintain on failure
|
||||
# Custom triggers for advanced KEDA configurations
|
||||
|
||||
@@ -196,7 +196,7 @@ members = ["backend", "tools/ods"]
|
||||
|
||||
[tool.basedpyright]
|
||||
include = ["backend"]
|
||||
exclude = ["backend/generated"]
|
||||
exclude = ["backend/generated", "backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx", "backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/venv"]
|
||||
typeCheckingMode = "off"
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -70,6 +70,9 @@ ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POP
|
||||
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
|
||||
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
# Add NODE_OPTIONS argument
|
||||
ARG NODE_OPTIONS
|
||||
|
||||
@@ -144,6 +147,9 @@ ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POP
|
||||
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
|
||||
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMaximize2 = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 14"
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
strokeWidth={2.5}
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M9 1H13M13 1V5M13 1L8.33333 5.66667M5 13H1M1 13V9M1 13L5.66667 8.33333"
|
||||
d="M10 2H14M14 2V6M14 2L9.33333 6.66667M6 14H2M2 14V10M2 14L6.66667 9.33333"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgMaximize2;
|
||||
|
||||
6
web/package-lock.json
generated
6
web/package-lock.json
generated
@@ -13748,9 +13748,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.1",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
|
||||
"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"side-channel": "^1.1.0"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SelectorFormField, TextFormField } from "@/components/Field";
|
||||
import { createApiKey, updateApiKey } from "./lib";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
@@ -10,14 +10,12 @@ import { APIKey } from "./types";
|
||||
import { SvgKey } from "@opal/icons";
|
||||
export interface OnyxApiKeyFormProps {
|
||||
onClose: () => void;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
onCreateApiKey: (apiKey: APIKey) => void;
|
||||
apiKey?: APIKey;
|
||||
}
|
||||
|
||||
export default function OnyxApiKeyForm({
|
||||
onClose,
|
||||
setPopup,
|
||||
onCreateApiKey,
|
||||
apiKey,
|
||||
}: OnyxApiKeyFormProps) {
|
||||
@@ -54,12 +52,11 @@ export default function OnyxApiKeyForm({
|
||||
}
|
||||
formikHelpers.setSubmitting(false);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: isUpdate
|
||||
toast.success(
|
||||
isUpdate
|
||||
? "Successfully updated API key!"
|
||||
: "Successfully created API key!",
|
||||
type: "success",
|
||||
});
|
||||
: "Successfully created API key!"
|
||||
);
|
||||
if (!isUpdate) {
|
||||
onCreateApiKey(await response.json());
|
||||
}
|
||||
@@ -67,12 +64,11 @@ export default function OnyxApiKeyForm({
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
setPopup({
|
||||
message: isUpdate
|
||||
toast.error(
|
||||
isUpdate
|
||||
? `Error updating API key - ${errorMsg}`
|
||||
: `Error creating API key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
: `Error creating API key - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
Table,
|
||||
} from "@/components/ui/table";
|
||||
import Title from "@/components/ui/title";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useState } from "react";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
@@ -33,8 +33,6 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
|
||||
function Main() {
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading,
|
||||
@@ -84,7 +82,6 @@ function Main() {
|
||||
if (filteredApiKeys.length === 0) {
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
{introSection}
|
||||
|
||||
{showCreateUpdateForm && (
|
||||
@@ -97,7 +94,6 @@ function Main() {
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
setPopup={setPopup}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
@@ -107,8 +103,6 @@ function Main() {
|
||||
|
||||
return (
|
||||
<>
|
||||
{popup}
|
||||
|
||||
<Modal open={!!fullApiKey}>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
@@ -171,10 +165,7 @@ function Main() {
|
||||
setKeyIsGenerating(false);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to regenerate API Key: ${errorMsg}`,
|
||||
});
|
||||
toast.error(`Failed to regenerate API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
const newKey = (await response.json()) as APIKey;
|
||||
@@ -191,10 +182,7 @@ function Main() {
|
||||
const response = await deleteApiKey(apiKey.api_key_id);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to delete API Key: ${errorMsg}`,
|
||||
});
|
||||
toast.error(`Failed to delete API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate("/api/admin/api-key");
|
||||
@@ -216,7 +204,6 @@ function Main() {
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
setPopup={setPopup}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -4,7 +4,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import { Persona } from "./interfaces";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import { UniqueIdentifier } from "@dnd-kit/core";
|
||||
import { DraggableTable } from "@/components/table/DraggableTable";
|
||||
@@ -56,7 +56,6 @@ export function PersonasTable({
|
||||
pageSize: number;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
const { refreshUser, isAdmin } = useUser();
|
||||
|
||||
const editablePersonas = useMemo(() => {
|
||||
@@ -109,10 +108,7 @@ export function PersonasTable({
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to update persona order - ${await response.text()}`,
|
||||
});
|
||||
toast.error(`Failed to update persona order - ${await response.text()}`);
|
||||
setFinalPersonas(personas);
|
||||
await refreshPersonas();
|
||||
return;
|
||||
@@ -139,10 +135,7 @@ export function PersonasTable({
|
||||
refreshPersonas();
|
||||
closeDeleteModal();
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to delete persona - ${await response.text()}`,
|
||||
});
|
||||
toast.error(`Failed to delete persona - ${await response.text()}`);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -167,17 +160,13 @@ export function PersonasTable({
|
||||
refreshPersonas();
|
||||
closeDefaultModal();
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to update persona - ${await response.text()}`,
|
||||
});
|
||||
toast.error(`Failed to update persona - ${await response.text()}`);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
{deleteModalOpen && personaToDelete && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgAlertCircle}
|
||||
@@ -290,10 +279,9 @@ export function PersonasTable({
|
||||
if (response.ok) {
|
||||
refreshPersonas();
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to update persona - ${await response.text()}`,
|
||||
});
|
||||
toast.error(
|
||||
`Failed to update persona - ${await response.text()}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
className={`
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
import { SlackTokensForm } from "./SlackTokensForm";
|
||||
@@ -17,7 +16,6 @@ export const NewSlackBotForm = () => {
|
||||
app_token: "",
|
||||
user_token: "",
|
||||
});
|
||||
const { popup, setPopup } = usePopup();
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
@@ -27,12 +25,10 @@ export const NewSlackBotForm = () => {
|
||||
title="New Slack Bot"
|
||||
/>
|
||||
<CardSection>
|
||||
{popup}
|
||||
<div className="p-4">
|
||||
<SlackTokensForm
|
||||
isUpdate={false}
|
||||
initialValues={formValues}
|
||||
setPopup={setPopup}
|
||||
router={router}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SlackBot, ValidSources } from "@/lib/types";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
@@ -24,7 +24,6 @@ export const ExistingSlackBotForm = ({
|
||||
}) => {
|
||||
const [isExpanded, setIsExpanded] = useState(false);
|
||||
const [formValues, setFormValues] = useState(existingSlackBot);
|
||||
const { popup, setPopup } = usePopup();
|
||||
const router = useRouter();
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
const [showDeleteModal, setShowDeleteModal] = useState(false);
|
||||
@@ -42,15 +41,9 @@ export const ExistingSlackBotForm = ({
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
setPopup({
|
||||
message: `Connector ${field} updated successfully`,
|
||||
type: "success",
|
||||
});
|
||||
toast.success(`Connector ${field} updated successfully`);
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to update connector ${field}`,
|
||||
type: "error",
|
||||
});
|
||||
toast.error(`Failed to update connector ${field}`);
|
||||
}
|
||||
setFormValues((prev) => ({ ...prev, [field]: value }));
|
||||
};
|
||||
@@ -74,7 +67,6 @@ export const ExistingSlackBotForm = ({
|
||||
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
<div className="flex items-center justify-between h-14">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="my-auto">
|
||||
@@ -120,7 +112,6 @@ export const ExistingSlackBotForm = ({
|
||||
initialValues={formValues}
|
||||
existingSlackBotId={existingSlackBot.id}
|
||||
refreshSlackBot={refreshSlackBot}
|
||||
setPopup={setPopup}
|
||||
router={router}
|
||||
onValuesChange={(values) => setFormValues(values)}
|
||||
/>
|
||||
@@ -149,16 +140,10 @@ export const ExistingSlackBotForm = ({
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
setPopup({
|
||||
message: "Slack bot deleted successfully",
|
||||
type: "success",
|
||||
});
|
||||
toast.success("Slack bot deleted successfully");
|
||||
router.push("/admin/bots");
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: "Failed to delete Slack bot",
|
||||
type: "error",
|
||||
});
|
||||
toast.error("Failed to delete Slack bot");
|
||||
}
|
||||
setShowDeleteModal(false);
|
||||
}}
|
||||
|
||||
@@ -8,13 +8,13 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { useEffect } from "react";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export const SlackTokensForm = ({
|
||||
isUpdate,
|
||||
initialValues,
|
||||
existingSlackBotId,
|
||||
refreshSlackBot,
|
||||
setPopup,
|
||||
router,
|
||||
onValuesChange,
|
||||
}: {
|
||||
@@ -22,7 +22,6 @@ export const SlackTokensForm = ({
|
||||
initialValues: any;
|
||||
existingSlackBotId?: number;
|
||||
refreshSlackBot?: () => void;
|
||||
setPopup: (popup: { message: string; type: "error" | "success" }) => void;
|
||||
router: any;
|
||||
onValuesChange?: (values: any) => void;
|
||||
}) => {
|
||||
@@ -59,12 +58,11 @@ export const SlackTokensForm = ({
|
||||
}
|
||||
const responseJson = await response.json();
|
||||
const botId = isUpdate ? existingSlackBotId : responseJson.id;
|
||||
setPopup({
|
||||
message: isUpdate
|
||||
toast.success(
|
||||
isUpdate
|
||||
? "Successfully updated Slack Bot!"
|
||||
: "Successfully created Slack Bot!",
|
||||
type: "success",
|
||||
});
|
||||
: "Successfully created Slack Bot!"
|
||||
);
|
||||
router.push(`/admin/bots/${encodeURIComponent(botId)}`);
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
@@ -75,12 +73,11 @@ export const SlackTokensForm = ({
|
||||
} else if (errorMsg.includes("Invalid app token:")) {
|
||||
errorMsg = "Slack App Token is invalid";
|
||||
}
|
||||
setPopup({
|
||||
message: isUpdate
|
||||
toast.error(
|
||||
isUpdate
|
||||
? `Error updating Slack Bot - ${errorMsg}`
|
||||
: `Error creating Slack Bot - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
: `Error creating Slack Bot - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
enableReinitialize={true}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { EditIcon } from "@/components/icons/icons";
|
||||
import { SlackChannelConfig } from "@/lib/types";
|
||||
import {
|
||||
@@ -27,14 +27,12 @@ export interface SlackChannelConfigsTableProps {
|
||||
slackBotId: number;
|
||||
slackChannelConfigs: SlackChannelConfig[];
|
||||
refresh: () => void;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}
|
||||
|
||||
export default function SlackChannelConfigsTable({
|
||||
slackBotId,
|
||||
slackChannelConfigs,
|
||||
refresh,
|
||||
setPopup,
|
||||
}: SlackChannelConfigsTableProps) {
|
||||
const [page, setPage] = useState(1);
|
||||
|
||||
@@ -130,16 +128,14 @@ export default function SlackChannelConfigsTable({
|
||||
slackChannelConfig.id
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: `Slack bot config "${slackChannelConfig.id}" deleted`,
|
||||
type: "success",
|
||||
});
|
||||
toast.success(
|
||||
`Slack bot config "${slackChannelConfig.id}" deleted`
|
||||
);
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete Slack bot config - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
toast.error(
|
||||
`Failed to delete Slack bot config - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
refresh();
|
||||
}}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import React, { useMemo } from "react";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
DocumentSetSummary,
|
||||
SlackChannelConfig,
|
||||
@@ -34,7 +34,6 @@ export const SlackChannelConfigCreationForm = ({
|
||||
standardAnswerCategoryResponse: StandardAnswerCategoryResponse;
|
||||
existingSlackChannelConfig?: SlackChannelConfig;
|
||||
}) => {
|
||||
const { popup, setPopup } = usePopup();
|
||||
const router = useRouter();
|
||||
const isUpdate = Boolean(existingSlackChannelConfig);
|
||||
const isDefault = existingSlackChannelConfig?.is_default || false;
|
||||
@@ -65,8 +64,6 @@ export const SlackChannelConfigCreationForm = ({
|
||||
|
||||
return (
|
||||
<CardSection className="!px-12 max-w-4xl">
|
||||
{popup}
|
||||
|
||||
<Formik
|
||||
initialValues={{
|
||||
slack_bot_id: slack_bot_id,
|
||||
@@ -221,12 +218,11 @@ export const SlackChannelConfigCreationForm = ({
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
setPopup({
|
||||
message: `Error ${
|
||||
toast.error(
|
||||
`Error ${
|
||||
isUpdate ? "updating" : "creating"
|
||||
} OnyxBot config - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
} OnyxBot config - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
@@ -241,7 +237,6 @@ export const SlackChannelConfigCreationForm = ({
|
||||
searchEnabledAssistants={searchEnabledAssistants}
|
||||
nonSearchAssistants={nonSearchAssistants}
|
||||
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
|
||||
setPopup={setPopup}
|
||||
slack_bot_id={slack_bot_id}
|
||||
formikProps={formikProps}
|
||||
/>
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { useState, useEffect, useMemo } from "react";
|
||||
import { FieldArray, useFormikContext, ErrorMessage } from "formik";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
Label,
|
||||
SelectorFormField,
|
||||
@@ -47,10 +48,6 @@ export interface SlackChannelConfigFormFieldsProps {
|
||||
searchEnabledAssistants: MinimalPersonaSnapshot[];
|
||||
nonSearchAssistants: MinimalPersonaSnapshot[];
|
||||
standardAnswerCategoryResponse: StandardAnswerCategoryResponse;
|
||||
setPopup: (popup: {
|
||||
message: string;
|
||||
type: "error" | "success" | "warning";
|
||||
}) => void;
|
||||
slack_bot_id: number;
|
||||
formikProps: any;
|
||||
}
|
||||
@@ -62,7 +59,6 @@ export function SlackChannelConfigFormFields({
|
||||
searchEnabledAssistants,
|
||||
nonSearchAssistants,
|
||||
standardAnswerCategoryResponse,
|
||||
setPopup,
|
||||
slack_bot_id,
|
||||
formikProps,
|
||||
}: SlackChannelConfigFormFieldsProps) {
|
||||
@@ -142,13 +138,11 @@ export function SlackChannelConfigFormFields({
|
||||
(dsId: number) => !invalidSelected.includes(dsId)
|
||||
)
|
||||
);
|
||||
setPopup({
|
||||
message:
|
||||
"We removed one or more document sets from your selection because they are no longer valid. Please review and update your configuration.",
|
||||
type: "warning",
|
||||
});
|
||||
toast.warning(
|
||||
"We removed one or more document sets from your selection because they are no longer valid. Please review and update your configuration."
|
||||
);
|
||||
}
|
||||
}, [unselectableSets, values.document_sets, setFieldValue, setPopup]);
|
||||
}, [unselectableSets, values.document_sets, setFieldValue]);
|
||||
|
||||
const shouldShowPrivacyAlert = useMemo(() => {
|
||||
if (values.knowledge_source === "document_sets") {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user