Compare commits

...

43 Commits

Author SHA1 Message Date
rohoswagger
bf3c98142d refactor(craft): Zed-style ACP session management — one process per sandbox
Instead of creating a fresh `opencode acp` process per prompt (or caching
one per session), use one long-lived process per sandbox with multiple ACP
sessions. This mirrors how Zed editor manages ACP connections.

Key changes:
- ACPExecClient: start() only does initialize (no session creation),
  new public create_session/resume_session/get_or_create_session methods,
  send_message accepts explicit session_id parameter, tracks multiple
  sessions in a dict
- KubernetesSandboxManager: _acp_clients keyed by sandbox_id (not
  session), _acp_session_ids maps (sandbox_id, session_id) → ACP session
  ID, switching sessions is implicit via sessionId in each prompt
- Tests rewritten for new architecture (8 passing)
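The per-sandbox layout described above can be sketched roughly as follows. This is a minimal illustration only: `ACPExecClient`, `start`, `create_session`, `get_or_create_session`, and `send_message` are named in the commit, but the method bodies, the transport, and the `get_client` helper are placeholders, not the actual implementation.

```python
class ACPExecClient:
    """One long-lived client per sandbox, tracking multiple ACP sessions."""

    def __init__(self, sandbox_id: str):
        self.sandbox_id = sandbox_id
        self.sessions: dict[str, str] = {}  # our session_id -> ACP session ID

    def start(self) -> None:
        # Only performs the ACP initialize handshake; no session is created here.
        pass

    def create_session(self, session_id: str) -> str:
        acp_session_id = f"acp-{session_id}"  # would come from session/new
        self.sessions[session_id] = acp_session_id
        return acp_session_id

    def get_or_create_session(self, session_id: str) -> str:
        return self.sessions.get(session_id) or self.create_session(session_id)

    def send_message(self, session_id: str, text: str) -> str:
        # Switching sessions is implicit: each prompt carries its own sessionId.
        acp_session_id = self.get_or_create_session(session_id)
        return f"prompt sent to {acp_session_id}"


# Manager side: clients are keyed by sandbox_id, not by session.
_acp_clients: dict[str, ACPExecClient] = {}


def get_client(sandbox_id: str) -> ACPExecClient:
    if sandbox_id not in _acp_clients:
        client = ACPExecClient(sandbox_id)
        client.start()
        _acp_clients[sandbox_id] = client
    return _acp_clients[sandbox_id]
```

Two prompts against different sessions of the same sandbox then reuse one process: `get_client("sb1")` returns the same cached client both times, and each `send_message` call routes by its session ID.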

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 12:04:56 -08:00
rohoswagger
f023599618 fix(craft): add diagnostic logging for hanging prompt debug + silence usage_update
Adds targeted logging to identify why Prompt #2 hangs after usage_update:
- Reader thread: logs buffer state when unterminated data detected
- Reader thread: periodic idle heartbeat every ~5s
- send_message: logs wait state every 3rd keepalive
- Silences usage_update (token stats) in _process_session_update

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:49:55 -08:00
rohoswagger
c55cb899f7 fix(craft): handle unterminated ACP messages + clean up diagnostic logging
The reader thread only processed JSON messages terminated with \n. If the
agent's final response lacked a trailing newline, it sat in the buffer
forever, causing send_message to hang with keepalives. Added stale-buffer
detection that parses unterminated content after ~0.3s of no new data.
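The stale-buffer behavior described above can be sketched as a plain function (an assumption-laden illustration: the real reader is a thread pulling from the agent process's stdout, while here `chunks` is just an iterable of (timestamp, text) pairs):

```python
import json

STALE_AFTER = 0.3  # seconds of no new data before parsing an unterminated buffer


def read_messages(chunks):
    """Yield JSON messages from newline-delimited chunks, flushing a stale
    unterminated buffer so a final response without a trailing newline does
    not sit in the buffer forever."""
    buffer = ""
    last_data_at = None
    for ts, chunk in chunks:
        if chunk:
            buffer += chunk
            last_data_at = ts
            # Normal path: process every \n-terminated message.
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                if line.strip():
                    yield json.loads(line)
        elif buffer and last_data_at is not None and ts - last_data_at >= STALE_AFTER:
            # Stale path: no new data for a while, so try parsing what we have.
            try:
                yield json.loads(buffer)
                buffer = ""
            except json.JSONDecodeError:
                pass  # partial JSON; keep waiting for more data
```

With this, a final `{"b": 2}` that arrives without a newline is still delivered once the idle threshold passes, instead of stalling the caller.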

Also cleaned up verbose diagnostic logging ([ACP-LIFECYCLE], [ACP-READER],
[ACP-SEND] prefixes) added during debugging — moved per-message noise to
debug level, kept important lifecycle events at info.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 21:46:00 -08:00
rohoswagger
9b8a6e60b7 fix(craft): resume ACP sessions across API replicas for follow-up messages
Multiple API server replicas each maintain independent in-memory ACP client
caches. When a follow-up message is routed to a different replica, it creates
a new opencode session with no conversation context.

Fix: After initializing a new opencode ACP process, try session/list (filtered
by workspace cwd) to discover existing sessions from previous processes, then
session/resume to restore conversation context. Falls back to session/new if
the agent doesn't support these methods or no sessions exist.

Also adds api_pod hostname to SANDBOX-ACP log lines for replica identification.
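The resume-or-create flow above can be sketched as follows. The JSON-RPC method names (session/list, session/resume, session/new) come from the commit message; `rpc` is a placeholder for the real ACP transport, and the error handling is illustrative, not the actual code.

```python
def get_session(rpc, workspace_cwd: str) -> str:
    """Restore a session created by a previous process (possibly on another
    API replica), falling back to a fresh session when that fails."""
    try:
        sessions = rpc("session/list", {"cwd": workspace_cwd})
    except Exception:
        sessions = []  # agent doesn't support session/list
    for session in sessions:
        try:
            rpc("session/resume", {"sessionId": session["sessionId"]})
            return session["sessionId"]  # conversation context restored
        except Exception:
            continue  # resume unsupported or session gone; try the next one
    # No resumable session: create a new one as before.
    return rpc("session/new", {"cwd": workspace_cwd})["sessionId"]
```

The key property is that a follow-up request landing on a replica with a cold cache still recovers the prior conversation, and an agent lacking these methods degrades gracefully to session/new.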

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:20:43 -08:00
rohoswagger
dd9d201b51 fix(craft): add extensive diagnostic logging for ACP follow-up messages
- Add [ACP-LIFECYCLE] logs for client start/stop with session IDs
- Add [ACP-READER] logs for every message read from WebSocket with
  update_type, queue size, and ACP session ID
- Add [ACP-SEND] logs for every dequeued message with prompt number,
  completion reason tracking, and queue state
- Add [SANDBOX-ACP] logs for cache hit/miss decisions and PromptResponse
  tracking in the sandbox manager
- Add stderr reading in reader thread to catch opencode errors
- Add queue drain at start of each send_message() to clear stale messages
- Track prompt_count per client to identify 1st vs 2nd+ prompts
- Log completion_reason (jsonrpc_response, notification, timeout, etc.)
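The queue-drain step in the list above can be sketched as a small helper (a minimal illustration under assumed names; the real client's queue handling and logging are richer):

```python
import queue


def drain(q: "queue.Queue") -> int:
    """Drop stale messages left over from a previous prompt so they cannot
    be mistaken for the new prompt's response. Returns the number dropped."""
    dropped = 0
    while True:
        try:
            q.get_nowait()
            dropped += 1
        except queue.Empty:
            return dropped
```

Called at the top of each send_message, this guarantees the reader queue starts empty before the new prompt's responses begin arriving.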

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:49:57 -08:00
rohoswagger
c545819aa6 fix(craft): harden ACP response matching for cached sessions
The cached ACPExecClient could emit keepalives forever on follow-up
messages because the PromptResponse was never matched. This adds:

- ID type-mismatch tolerance (str fallback for int/str id comparison)
- Guard against agent request ID collision (require no "method" field)
- PromptResponse via session/update notification handler
- Reader thread health check (detect dead WebSocket, stop keepalives)
- Buffer flush on reader thread exit (catch trailing PromptResponse)
- Diagnostic logging for every dequeued message and dropped messages
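The first two bullets (ID type-mismatch tolerance and the agent-request guard) can be sketched as a single matching predicate. This is an assumed shape, not the actual code: per JSON-RPC conventions, a message carrying a "method" field is a request or notification, never the response we are waiting on, even if its id collides with ours.

```python
def is_prompt_response(message: dict, expected_id) -> bool:
    """Decide whether a dequeued JSON-RPC message is our PromptResponse."""
    if "method" in message:
        return False  # agent-initiated request whose id happens to collide
    msg_id = message.get("id")
    if msg_id == expected_id:
        return True
    # Type-mismatch tolerance: some agents echo an int id back as a string.
    return msg_id is not None and str(msg_id) == str(expected_id)
```

Without the "method" guard, an agent request reusing id 3 would be consumed as the response to our prompt 3 and the real response would then look like a duplicate.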

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:45:39 -08:00
rohoswagger
960ee228bf fix(craft): cache ACPExecClient in K8s sandbox to fix follow-up messages 2026-02-15 11:14:10 -08:00
Yuhong Sun
dea5be2185 chore: License update (No change, just touchup) (#8460) 2026-02-14 02:44:38 +00:00
Wenxi
d083973d4f chore: disable auto craft animation with feature flag (#8459) 2026-02-14 02:29:37 +00:00
Wenxi
df956888bf fix: bake public recaptcha key in cloud image (#8458) 2026-02-14 02:12:43 +00:00
dependabot[bot]
7c6062e7d5 chore(deps): bump qs from 6.14.1 to 6.14.2 in /web (#8451)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-14 02:04:30 +00:00
Yuhong Sun
89d2759021 chore: Remove end-of-lifed backend routes (#8453) 2026-02-14 01:57:06 +00:00
Justin Tahara
d9feaf43a7 chore(playwright): Adding new LLM Runtime tests (#8447) 2026-02-14 01:38:23 +00:00
Nikolas Garza
5bfffefa2f feat(scim): add SCIM filter expression parser with unit tests (#8421) 2026-02-14 01:17:48 +00:00
Nikolas Garza
4d0b7e14d4 feat(scim): add SCIM PATCH operation handler with unit tests (#8422) 2026-02-14 01:12:46 +00:00
Jamison Lahman
36c55d9e59 chore(gha): de-duplicate integration test logic (#8450) 2026-02-14 00:31:31 +00:00
Wenxi
9f652108f9 fix: don't pass captcha token to db (#8449) 2026-02-14 00:20:36 +00:00
victoria reese
d4e4c6b40e feat: add setting to configure mcp host (#8439) 2026-02-13 23:49:18 +00:00
Jamison Lahman
9c8deb5d0c chore(playwright): mask non-deterministic email element (#8448) 2026-02-13 23:37:24 +00:00
Danelegend
58f57c43aa feat(contextual-llm): Populate and set w/ llm flow (#8398) 2026-02-13 23:32:26 +00:00
Evan Lohn
62106df753 fix: sharepoint cred refresh2 (#8445) 2026-02-13 23:05:15 +00:00
Jamison Lahman
45b3a5e945 chore(playwright): include option to hide element in screenshots (#8446) 2026-02-13 22:45:46 +00:00
Jamison Lahman
e19a6b6789 chore(playwright): create new user tests (#8429) 2026-02-13 22:17:18 +00:00
Jamison Lahman
2de7df4839 chore(playwright): login page screenshots (#8427) 2026-02-13 22:01:32 +00:00
victoria reese
bd054bbad9 fix: remove default idleReplicaCount (#8434) 2026-02-13 13:37:19 -08:00
Justin Tahara
313e709d41 fix(celery): Respecting Limits for Celery Heavy Tasks (#8407) 2026-02-13 21:27:04 +00:00
Nikolas Garza
aeb1d6edac feat(scim): add SCIM 2.0 Pydantic schemas (#8420) 2026-02-13 21:21:05 +00:00
Wenxi
49a35f8aaa fix: remove user file indexing from launch, add init imports for all celery tasks, bump sandbox memory limits (#8443) 2026-02-13 21:15:30 +00:00
Danelegend
049e8ef0e2 feat(llm): Populate env w/ custom config (#8328) 2026-02-13 21:11:49 +00:00
Jamison Lahman
3b61b495a3 chore(playwright): tag appearance_theme tests exclusive (#8441) 2026-02-13 21:07:57 +00:00
Wenxi
5c5c9f0e1d feat(airtable): index all and hierarchy for craft (#8414) 2026-02-13 21:03:53 +00:00
Nikolas Garza
f20d5c33b7 feat(scim): add SCIM database models and migration (#8419) 2026-02-13 20:54:56 +00:00
Jamison Lahman
e898407f7b chore(tests): skip yet another test_web_search_api test (#8442) 2026-02-13 12:50:04 -08:00
Jamison Lahman
f802ff09a7 chore(tests): skip additional web_search test (#8440) 2026-02-13 12:29:36 -08:00
Jamison Lahman
69ad712e09 chore(tests): temporarily disable exa tests (#8431) 2026-02-13 11:06:25 -08:00
Jamison Lahman
98b69c0f2c chore(playwright): welcome_page tests & per-element screenshots (#8426) 2026-02-13 10:07:27 -08:00
Raunak Bhagat
1e5c87896f refactor(web): migrate from usePopup/setPopup to global toast system (#8411)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:21:14 +00:00
Raunak Bhagat
b6cc97a8c3 fix(web): icon button and timeline header UI fixes (#8416) 2026-02-13 17:20:37 +00:00
Yuhong Sun
032fbf1058 chore: reminder prompt to be moveable (#8417) 2026-02-13 07:39:12 +00:00
SubashMohan
fc32a9f92a fix(memory): memory tool UI and prompt injection issues (#8377) 2026-02-13 04:29:51 +00:00
Jamison Lahman
9be13bbf63 chore(playwright): make screenshots deterministic (#8412) 2026-02-12 19:53:11 -08:00
Yuhong Sun
9e7176eb82 chore: Tiny intro message change (#8415) 2026-02-12 19:44:34 -08:00
Yuhong Sun
c7faf8ce52 chore: Project instructions would get ignored (#8409) 2026-02-13 02:51:13 +00:00
261 changed files with 8313 additions and 5412 deletions

View File

@@ -640,6 +640,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
@@ -721,6 +722,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

View File

@@ -46,6 +46,7 @@ jobs:
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
editions: ${{ steps.set-editions.outputs.editions }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
@@ -72,6 +73,16 @@ jobs:
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
- name: Determine editions to test
id: set-editions
run: |
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
else
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
fi
build-backend-image:
runs-on:
[
@@ -267,7 +278,7 @@ jobs:
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
@@ -275,6 +286,7 @@ jobs:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -298,12 +310,11 @@ jobs:
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
EDITION: ${{ matrix.edition }}
run: |
# Base config shared by both editions
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -312,11 +323,20 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
# EE-only config
if [ "$EDITION" = "ee" ]; then
cat <<EOF >> deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
- name: Start Docker containers
run: |
@@ -379,14 +399,14 @@ jobs:
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -444,7 +464,7 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------

View File

@@ -1,443 +0,0 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
types: [checks_requested]
push:
tags:
- "v*.*.*"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
}
wait_for_service "http://localhost:8080/health" "API server"
echo "Finished waiting for services."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

.vscode/launch.json (vendored), 4 changed lines
View File

@@ -275,7 +275,7 @@
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
@@ -419,7 +419,7 @@
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
"connector_doc_fetching"
],
"presentation": {
"group": "2"

View File

@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
Portions of this software are licensed as follows:
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

View File

@@ -0,0 +1,71 @@
"""Migrate to contextual rag model
Revision ID: 19c0ccb01687
Revises: 9c54986124c6
Create Date: 2026-02-12 11:21:41.798037
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "19c0ccb01687"
down_revision = "9c54986124c6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
# when the table was created with only CHAT/VISION values.
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=20),
existing_type=sa.String(length=10),
existing_nullable=False,
)
# For every search_settings row that has contextual rag configured,
# create an llm_model_flow entry. is_default is TRUE if the row
# belongs to the PRESENT search settings, FALSE otherwise.
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
SELECT DISTINCT
'CONTEXTUAL_RAG',
mc.id,
(ss.status = 'PRESENT')
FROM search_settings ss
JOIN llm_provider lp
ON lp.name = ss.contextual_rag_llm_provider
JOIN model_configuration mc
ON mc.llm_provider_id = lp.id
AND mc.name = ss.contextual_rag_llm_name
WHERE ss.enable_contextual_rag = TRUE
AND ss.contextual_rag_llm_name IS NOT NULL
AND ss.contextual_rag_llm_provider IS NOT NULL
ON CONFLICT (llm_model_flow_type, model_configuration_id)
DO UPDATE SET is_default = EXCLUDED.is_default
WHERE EXCLUDED.is_default = TRUE
"""
)
def downgrade() -> None:
op.execute(
"""
DELETE FROM llm_model_flow
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
"""
)
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=10),
existing_type=sa.String(length=20),
existing_nullable=False,
)
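The `ON CONFLICT ... DO UPDATE ... WHERE EXCLUDED.is_default = TRUE` clause above makes the upsert promote-only: an existing row can gain default status but never lose it. A minimal sketch of that semantics using SQLite's equivalent upsert syntax (the table here is a toy stand-in, not the real Onyx schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE llm_model_flow (
        llm_model_flow_type TEXT,
        model_configuration_id INTEGER,
        is_default BOOLEAN,
        UNIQUE (llm_model_flow_type, model_configuration_id)
    )"""
)

upsert = """
    INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
    VALUES (?, ?, ?)
    ON CONFLICT (llm_model_flow_type, model_configuration_id)
    DO UPDATE SET is_default = excluded.is_default
    WHERE excluded.is_default = TRUE
"""

# First insert creates the row with is_default = FALSE.
conn.execute(upsert, ("CONTEXTUAL_RAG", 1, False))
# Conflicting insert with is_default = TRUE: the WHERE clause permits the update.
conn.execute(upsert, ("CONTEXTUAL_RAG", 1, True))
assert conn.execute("SELECT is_default FROM llm_model_flow").fetchone()[0] == 1
# Conflicting insert with is_default = FALSE: the WHERE clause rejects the
# update, so the existing TRUE value survives (no demotion).
conn.execute(upsert, ("CONTEXTUAL_RAG", 1, False))
assert conn.execute("SELECT is_default FROM llm_model_flow").fetchone()[0] == 1
```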

View File

@@ -0,0 +1,124 @@
"""add_scim_tables
Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"scim_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("hashed_token", sa.String(length=64), nullable=False),
sa.Column("token_display", sa.String(), nullable=False),
sa.Column(
"created_by_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_token"),
)
op.create_table(
"scim_group_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_group_id"),
)
op.create_index(
op.f("ix_scim_group_mapping_external_id"),
"scim_group_mapping",
["external_id"],
unique=True,
)
op.create_table(
"scim_user_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
)
op.create_index(
op.f("ix_scim_user_mapping_external_id"),
"scim_user_mapping",
["external_id"],
unique=True,
)
def downgrade() -> None:
op.drop_index(
op.f("ix_scim_user_mapping_external_id"),
table_name="scim_user_mapping",
)
op.drop_table("scim_user_mapping")
op.drop_index(
op.f("ix_scim_group_mapping_external_id"),
table_name="scim_group_mapping",
)
op.drop_table("scim_group_mapping")
op.drop_table("scim_token")

View File

@@ -1,20 +1,20 @@
The DanswerAI Enterprise license (the Enterprise License)
The Onyx Enterprise License (the "Enterprise License")
Copyright (c) 2023-present DanswerAI, Inc.
With regard to the Onyx Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the Enterprise Terms), or other
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise license for the
and otherwise have a valid Onyx Enterprise License for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise license for the correct
exploited with a valid Onyx Enterprise License for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain

View File

@@ -536,7 +536,9 @@ def connector_permission_sync_generator_task(
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
callback = PermissionSyncCallback(
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
)
# pass in the capability to fetch all existing docs for the cc_pair
# this can be used to determine documents that are "missing" and thus
@@ -576,6 +578,13 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
if callback.should_stop():
raise RuntimeError(
f"Permission sync task timed out or stop signal detected: "
f"cc_pair={cc_pair_id} "
f"tasks_generated={tasks_generated}"
)
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
@@ -932,6 +941,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
@@ -944,11 +954,26 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"PermissionSyncCallback - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002

View File

@@ -466,6 +466,7 @@ def connector_external_group_sync_generator_task(
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
timeout_seconds: int = JOB_TIMEOUT,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
@@ -518,9 +519,23 @@ def _perform_external_group_sync(
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
start_time = time.monotonic()
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
elapsed = time.monotonic() - start_time
if elapsed > timeout_seconds:
raise RuntimeError(
f"External group sync task timed out: "
f"cc_pair={cc_pair_id} "
f"elapsed={elapsed:.0f}s "
f"timeout={timeout_seconds}s "
f"groups_processed={total_groups_processed}"
)
external_user_group_batch.append(external_user_group)
# Track progress

View File

@@ -1,12 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -46,16 +43,11 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)

View File

@@ -27,6 +27,8 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. ensure this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None

View File

@@ -67,6 +67,8 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
# NOTE: This endpoint is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. ensure this remains backwards compatible across versions.
@router.post(
"/send-search-message",
response_model=None,

View File

@@ -0,0 +1,96 @@
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
resources before deciding whether to create or update them. For example, when
an admin assigns a user to the Onyx app, the IdP first checks whether that
user already exists::
GET /scim/v2/Users?filter=userName eq "john@example.com"
If zero results come back the IdP creates the user (``POST``); if a match is
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
The same pattern applies to groups (``displayName eq "Engineering"``).
This module parses the subset of the SCIM filter grammar that identity
providers actually send in practice:
attribute SP operator SP value
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser raises ``ValueError`` and the caller falls back to an unfiltered list.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
class ScimFilterOperator(str, Enum):
"""Supported SCIM filter operators."""
EQUAL = "eq"
CONTAINS = "co"
STARTS_WITH = "sw"
@dataclass(frozen=True, slots=True)
class ScimFilter:
"""Parsed SCIM filter expression."""
attribute: str
operator: ScimFilterOperator
value: str
# Matches: attribute operator "value" (value must be double- or single-quoted)
# Groups: (attribute) (operator) (double-quoted value | single-quoted value)
_FILTER_RE = re.compile(
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
r'(?:"([^"]*)"' # quoted value
r"|'([^']*)')" # or single-quoted value
r"$",
re.IGNORECASE,
)
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
"""Parse a simple SCIM filter expression.
Args:
filter_string: Raw filter query parameter value, e.g.
``'userName eq "john@example.com"'``
Returns:
A ``ScimFilter`` if the expression is valid and uses a supported
operator, or ``None`` if the input is empty / missing.
Raises:
ValueError: If the filter string is present but malformed or uses
an unsupported operator.
"""
if not filter_string or not filter_string.strip():
return None
match = _FILTER_RE.match(filter_string.strip())
if not match:
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
return _build_filter(match, filter_string)
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
"""Extract fields from a regex match and construct a ScimFilter."""
attribute = match.group(1)
op_str = match.group(2).lower()
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
value = match.group(3) if match.group(3) is not None else match.group(4)
if value is None:
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
operator = ScimFilterOperator(op_str)
return ScimFilter(attribute=attribute, operator=operator, value=value)
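A quick sketch of what the filter grammar accepts and rejects, reproducing the pattern inline so the snippet is self-contained (`FILTER_RE` mirrors `_FILTER_RE` above):

```python
import re

# Same shape as _FILTER_RE: attribute, operator (eq|co|sw), quoted value.
FILTER_RE = re.compile(
    r"^(\S+)\s+(eq|co|sw)\s+"
    r'(?:"([^"]*)"'
    r"|'([^']*)')"
    r"$",
    re.IGNORECASE,
)

m = FILTER_RE.match('userName eq "john@example.com"')
assert m is not None
assert m.group(1) == "userName"
assert m.group(2) == "eq"
assert m.group(3) == "john@example.com"

# Single-quoted values land in group 4 instead of group 3.
m = FILTER_RE.match("displayName sw 'Eng'")
assert m is not None and m.group(4) == "Eng"

# Compound filters do not match, which the module surfaces as ValueError.
assert FILTER_RE.match('userName eq "a" and active eq "true"') is None
```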

View File

@@ -0,0 +1,255 @@
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
SCIM protocol schemas follow the wire format defined in:
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
# ---------------------------------------------------------------------------
# SCIM Schema URIs (RFC 7643 §8)
# Every SCIM JSON payload includes a "schemas" array identifying its type.
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
# ---------------------------------------------------------------------------
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
# ---------------------------------------------------------------------------
# SCIM Protocol Schemas
# ---------------------------------------------------------------------------
class ScimName(BaseModel):
"""User name components (RFC 7643 §4.1.1)."""
givenName: str | None = None
familyName: str | None = None
formatted: str | None = None
class ScimEmail(BaseModel):
"""Email sub-attribute (RFC 7643 §4.1.2)."""
value: str
type: str | None = None
primary: bool = False
class ScimMeta(BaseModel):
"""Resource metadata (RFC 7643 §3.1)."""
resourceType: str | None = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
This is the JSON shape that IdPs send when creating/updating a user via
SCIM, and the shape we return in GET responses. Field names use camelCase
to match the SCIM wire format (not Python convention).
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
meta: ScimMeta | None = None
class ScimGroupMember(BaseModel):
"""Group member reference (RFC 7643 §4.2).
Represents a user within a SCIM group. The IdP sends these when adding
or removing users from groups. ``value`` is the Onyx user ID.
"""
value: str # User ID of the group member
display: str | None = None
class ScimGroupResource(BaseModel):
"""SCIM Group resource representation (RFC 7643 §4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
id: str | None = None
externalId: str | None = None
displayName: str
members: list[ScimGroupMember] = Field(default_factory=list)
meta: ScimMeta | None = None
class ScimListResponse(BaseModel):
"""Paginated list response (RFC 7644 §3.4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
totalResults: int
startIndex: int = 1
itemsPerPage: int = 100
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
class ScimPatchOperationType(str, Enum):
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
ADD = "add"
REPLACE = "replace"
REMOVE = "remove"
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).
IdPs use PATCH to make incremental changes — e.g. deactivating a user
(replace active=false) or adding/removing group members — instead of
replacing the entire resource with PUT.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
Operations: list[ScimPatchOperation]
class ScimError(BaseModel):
"""SCIM error response (RFC 7644 §3.12)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
status: str
detail: str | None = None
scimType: str | None = None
# ---------------------------------------------------------------------------
# Service Provider Configuration (RFC 7643 §5)
# ---------------------------------------------------------------------------
class ScimSupported(BaseModel):
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
supported: bool
class ScimFilterConfig(BaseModel):
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
supported: bool
maxResults: int = 100
class ScimServiceProviderConfig(BaseModel):
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
initial setup to discover which SCIM features our server supports
(e.g. PATCH yes, bulk no, filtering yes).
"""
schemas: list[str] = Field(
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
)
patch: ScimSupported = ScimSupported(supported=True)
bulk: ScimSupported = ScimSupported(supported=False)
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
changePassword: ScimSupported = ScimSupported(supported=False)
sort: ScimSupported = ScimSupported(supported=False)
etag: ScimSupported = ScimSupported(supported=False)
authenticationSchemes: list[dict[str, str]] = Field(
default_factory=lambda: [
{
"type": "oauthbearertoken",
"name": "OAuth Bearer Token",
"description": "Authentication scheme using a SCIM bearer token",
}
]
)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True)
schema_: str = Field(alias="schema")
required: bool
class ScimResourceType(BaseModel):
"""SCIM ResourceType resource (RFC 7643 §6).
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str
name: str
endpoint: str
description: str | None = None
schema_: str = Field(alias="schema")
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Admin API Schemas (Onyx-internal, for SCIM token management)
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
# ---------------------------------------------------------------------------
class ScimTokenCreate(BaseModel):
"""Request to create a new SCIM bearer token."""
name: str
class ScimTokenResponse(BaseModel):
"""SCIM token metadata returned in list/get responses."""
id: int
name: str
token_display: str
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):
"""Response returned when a new SCIM token is created.
Includes the raw token value which is only available at creation time.
"""
raw_token: str
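One subtlety in these models is worth calling out: `schema` collides with names on Pydantic's `BaseModel`, so `ScimResourceType` stores the field as `schema_` with a wire-format alias. A minimal sketch of that pattern, assuming Pydantic v2:

```python
from pydantic import BaseModel, ConfigDict, Field

class ResourceType(BaseModel):
    # populate_by_name lets callers construct with either "schema_" or "schema".
    model_config = ConfigDict(populate_by_name=True)

    id: str
    endpoint: str
    schema_: str = Field(alias="schema")

rt = ResourceType(
    id="User",
    endpoint="/Users",
    schema_="urn:ietf:params:scim:schemas:core:2.0:User",
)

# by_alias=True emits the SCIM wire-format key "schema", not "schema_".
wire = rt.model_dump(by_alias=True)
assert wire["schema"] == "urn:ietf:params:scim:schemas:core:2.0:User"
assert "schema_" not in wire
```

Without `by_alias=True` the dump would contain `schema_`, which an IdP would not recognize, so SCIM response handlers need to serialize with aliases enabled.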

View File

@@ -0,0 +1,256 @@
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
Identity providers use PATCH to make incremental changes to SCIM resources
instead of replacing the entire resource with PUT. Common operations include:
- Deactivating a user: ``replace`` ``active`` with ``false``
- Adding group members: ``add`` to ``members``
- Removing group members: ``remove`` from ``members[value eq "..."]``
This module applies PATCH operations to Pydantic SCIM resource objects and
returns the modified result. It does NOT touch the database — the caller is
responsible for persisting changes.
"""
from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimUserResource
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
def __init__(self, detail: str, status: int = 400) -> None:
self.detail = detail
self.status = status
super().__init__(detail)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
name_data = data.get("name") or {}
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
current_members: list[dict] = list(data.get("members") or [])
added_ids: list[str] = []
removed_ids: list[str] = []
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_group_remove(op, current_members, removed_ids)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on Group resource"
)
data["members"] = current_members
group = ScimGroupResource.model_validate(data)
return group, added_ids, removed_ids
def _apply_group_replace(
op: ScimPatchOperation,
data: dict,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data)
def _replace_members(
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
removed_ids.extend(old_ids - new_ids)
added_ids.extend(new_ids - old_ids)
current_members[:] = value
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
op: ScimPatchOperation,
members: list[dict],
added_ids: list[str],
) -> None:
"""Add members to a group."""
path = (op.path or "").lower()
if path and path != "members":
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
existing_ids = {m["value"] for m in members}
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)
added_ids.append(member_id)
existing_ids.add(member_id)
def _apply_group_remove(
op: ScimPatchOperation,
members: list[dict],
removed_ids: list[str],
) -> None:
"""Remove members from a group."""
if not op.path:
raise ScimPatchError("Remove operation requires a path")
match = _MEMBER_FILTER_RE.match(op.path)
if not match:
raise ScimPatchError(
f"Unsupported remove path '{op.path}'. "
'Expected: members[value eq "user-id"]'
)
target_id = match.group(1)
original_len = len(members)
members[:] = [m for m in members if m.get("value") != target_id]
if len(members) < original_len:
removed_ids.append(target_id)
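The member-removal path filter can be exercised on its own; this sketch reproduces the regex inline (`MEMBER_FILTER_RE` mirrors `_MEMBER_FILTER_RE` above) together with the slice-assignment removal the handler uses:

```python
import re

# Same shape as _MEMBER_FILTER_RE: members[value eq "user-id"]
MEMBER_FILTER_RE = re.compile(
    r'^members\[value\s+eq\s+"([^"]+)"\]$',
    re.IGNORECASE,
)

members = [{"value": "u1"}, {"value": "u2"}]

path = 'members[value eq "u1"]'
match = MEMBER_FILTER_RE.match(path)
assert match is not None
target_id = match.group(1)

# In-place filter via slice assignment, so callers holding a reference to
# the same list see the removal (mirrors members[:] = ... in the handler).
members[:] = [m for m in members if m.get("value") != target_id]
assert members == [{"value": "u2"}]
```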

View File

@@ -1,7 +1,9 @@
import uuid
from enum import Enum
from typing import Any
from fastapi_users import schemas
from typing_extensions import override
class UserRole(str, Enum):
@@ -41,8 +43,21 @@ class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
captcha_token: str | None = None
@override
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
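The `captcha_token` exclusion above can be demonstrated without fastapi-users; the base class below is a hypothetical stand-in for `schemas.BaseUserCreate` and its dict builder:

```python
class BaseCreate:
    """Stand-in for fastapi-users' BaseUserCreate dict builder (illustrative)."""

    def __init__(self, **fields: object) -> None:
        self._fields = dict(fields)

    def create_update_dict(self) -> dict:
        return dict(self._fields)

class UserCreate(BaseCreate):
    def create_update_dict(self) -> dict:
        # captcha_token is a transient, request-only field: strip it here so
        # it never reaches the DB layer.
        d = super().create_update_dict()
        d.pop("captcha_token", None)
        return d

uc = UserCreate(email="a@b.co", captcha_token="tok-123")
d = uc.create_update_dict()
assert "captcha_token" not in d
assert d["email"] == "a@b.co"
```

The real schema overrides both `create_update_dict` and `create_update_dict_superuser` because fastapi-users routes each through a different builder.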

View File

@@ -37,6 +37,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.parent_pid = parent_pid
@@ -51,11 +52,29 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
if bool(self.redis_connector.stop.fenced):
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"IndexingCallback Docprocessing - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
"""Amount isn't used yet."""

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -146,14 +146,26 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
"""Collect metrics about queue lengths for different Celery queues"""
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
}
for name, queue in queue_mappings.items():
@@ -881,7 +893,7 @@ def monitor_celery_queues_helper(
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
@@ -908,6 +920,26 @@ def monitor_celery_queues_helper(
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_hierarchy_fetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
)
n_llm_model_update = celery_get_queue_length(
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
)
n_checkpoint_cleanup = celery_get_queue_length(
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
)
n_index_attempt_cleanup = celery_get_queue_length(
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
)
n_csv_generation = celery_get_queue_length(
OnyxCeleryQueues.CSV_GENERATION, r_celery
)
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
n_opensearch_migration = celery_get_queue_length(
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
)
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -931,6 +963,14 @@ def monitor_celery_queues_helper(
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
f"hierarchy_fetching={n_hierarchy_fetching} "
f"llm_model_update={n_llm_model_update} "
f"checkpoint_cleanup={n_checkpoint_cleanup} "
f"index_attempt_cleanup={n_index_attempt_cleanup} "
f"csv_generation={n_csv_generation} "
f"monitoring={n_monitoring} "
f"sandbox={n_sandbox} "
f"opensearch_migration={n_opensearch_migration} "
)
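
The name-to-queue mapping above drives one metric per entry. A minimal sketch of that pattern, assuming the Celery Redis broker keeps each queue as a Redis list keyed by the queue name (so `celery_get_queue_length` is essentially `LLEN`); `FakeRedis` and the plain-string queue names are stand-ins, not the real `OnyxCeleryQueues` values:

```python
class FakeRedis:
    """Minimal stand-in for a Redis client exposing llen()."""

    def __init__(self, queues):
        self._queues = queues

    def llen(self, key):
        # Redis LLEN returns 0 for a missing key, mirrored here.
        return len(self._queues.get(key, []))


def collect_queue_lengths(r, queue_mappings):
    # One metric per entry, mirroring the name -> queue mapping above.
    return {name: r.llen(queue) for name, queue in queue_mappings.items()}


r = FakeRedis({"celery": ["t1", "t2"], "monitoring": ["t3"]})
mappings = {
    "celery_queue_length": "celery",
    "monitoring_queue_length": "monitoring",
    "sandbox_queue_length": "sandbox",
}
lengths = collect_queue_lengths(r, mappings)
# lengths == {"celery_queue_length": 2, "monitoring_queue_length": 1, "sandbox_queue_length": 0}
```

Driving both the collection loop and the log line from one mapping keeps the metric names and queue keys from drifting apart as queues are added.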

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -523,6 +523,7 @@ def connector_pruning_generator_task(
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
# a list of docs in the source

View File

@@ -3,34 +3,26 @@ from collections.abc import Callable
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from onyx.auth.users import is_user_admin
from onyx.chat.models import ChatHistoryResult
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_or_create_root_message
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import check_project_ownership
from onyx.file_processing.extract_file_text import extract_file_text
@@ -47,9 +39,6 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
@@ -278,70 +267,6 @@ def extract_headers(
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
# Use the first prompt from the override config for embedded prompt fields
first_prompt = persona_config.prompts[0]
persona.system_prompt = first_prompt.system_prompt
persona.task_prompt = first_prompt.task_prompt
persona.datetime_aware = first_prompt.datetime_aware
persona.tools = []
if persona_config.custom_tools_openapi:
from onyx.chat.emitter import get_default_emitter
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=schema,
emitter=get_default_emitter(),
),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona
def process_kg_commands(
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
) -> None:
@@ -688,28 +613,34 @@ def convert_chat_history(
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
"""Get the custom agent prompt from persona or project instructions.
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
Priority: persona.system_prompt > chat_session.project.instructions > None
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions
# NOTE: Logic elsewhere allows saving empty strings for other purposes, but when constructing prompts
# we never want to return an empty string, so an empty prompt is translated into an explicit None.
Args:
persona: The Persona object
chat_session: The ChatSession object
Returns:
The custom agent prompt string, or None if neither persona nor project has one
The prompt to use for the custom Agent part of the prompt.
"""
# Not considered a custom agent if it's the default behavior persona
if persona.id == DEFAULT_PERSONA_ID:
return None
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
if persona.id != DEFAULT_PERSONA_ID:
# Logic later in the flow also drops it in this case, but returning None here is strictly correct anyhow.
if persona.replace_base_system_prompt:
return None
return persona.system_prompt or None
if persona.system_prompt:
return persona.system_prompt
elif chat_session.project and chat_session.project.instructions:
# If in a project and using the default Agent, respect the project instructions.
if chat_session.project and chat_session.project.instructions:
return chat_session.project.instructions
else:
return None
return None
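
The branching above can be condensed into a pure function for illustration. A minimal sketch, with simplified scalar inputs in place of the real `Persona`/`ChatSession` models; `pick_custom_prompt` is a hypothetical helper, not part of the codebase:

```python
DEFAULT_PERSONA_ID = 0  # stand-in for the real constant


def pick_custom_prompt(persona_id, system_prompt, replace_base, project_instructions):
    if persona_id != DEFAULT_PERSONA_ID:
        # Custom agent: its prompt wins, unless it replaces the base system
        # prompt, in which case it is not a "custom agent" prompt at all.
        if replace_base:
            return None
        return system_prompt or None  # empty string -> explicit None
    if project_instructions:
        # Default agent inside a project: fall back to project instructions.
        return project_instructions
    return None


assert pick_custom_prompt(7, "Be terse.", False, None) == "Be terse."
assert pick_custom_prompt(7, "", False, "ignored") is None
assert pick_custom_prompt(0, "unused", False, "Project rules") == "Project rules"
```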
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:

View File

@@ -38,7 +38,6 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -594,6 +593,7 @@ def run_llm_loop(
reasoning_cycles = 0
for llm_cycle_count in range(MAX_LLM_CYCLES):
# Handling tool calls based on cycle count and past cycle conditions
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
if forced_tool_id:
# Needs to be just the single one because "required" mode currently doesn't specify a particular tool, only a binary choice
@@ -615,6 +615,7 @@ def run_llm_loop(
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
# Handling the system prompt and custom agent prompt
# The section below calculates the available tokens for history a bit more accurately
# now that project files are loaded in.
if persona and persona.replace_base_system_prompt:
@@ -632,12 +633,14 @@ def run_llm_loop(
else:
# If it's an empty string, we assume the user does not want to include it as an empty System message
if default_base_system_prompt:
open_ai_formatting_enabled = model_needs_formatting_reenabled(
llm.config.model_name
)
prompt_memory_context = (
user_memory_context if inject_memories_in_prompt else None
user_memory_context
if inject_memories_in_prompt
else (
user_memory_context.without_memories()
if user_memory_context
else None
)
)
system_prompt_str = build_system_prompt(
base_system_prompt=default_base_system_prompt,
@@ -646,7 +649,6 @@ def run_llm_loop(
tools=tools,
should_cite_documents=should_cite_documents
or always_cite_documents,
open_ai_formatting_enabled=open_ai_formatting_enabled,
)
system_prompt = ChatMessageSimple(
message=system_prompt_str,

View File

@@ -36,6 +36,8 @@ from onyx.llm.models import ToolCall
from onyx.llm.models import ToolMessage
from onyx.llm.models import UserMessage
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
from onyx.server.query_and_chat.placement import Placement
@@ -623,6 +625,17 @@ def translate_history_to_llm_format(
f"Unknown message type {msg.message_type} in history. Skipping message."
)
# Apply model-specific formatting when translating to LLM format (e.g. OpenAI
# reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation)
if model_needs_formatting_reenabled(llm_config.model_name):
for i, m in enumerate(messages):
if isinstance(m, SystemMessage):
messages[i] = SystemMessage(
role="system",
content=CODE_BLOCK_MARKDOWN + m.content,
)
break
# prompt caching: rely on should_cache in ChatMessageSimple to
# pick the split point for the cacheable prefix and suffix
if last_cacheable_msg_idx != -1:
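
The prefixing step above can be sketched in isolation. Assumptions: the message classes are reduced to dataclasses, and the value of `CODE_BLOCK_MARKDOWN` shown here is an invented stand-in for the real constant; only the first `SystemMessage` receives the prefix, matching the `break` in the loop.

```python
from dataclasses import dataclass

CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "  # stand-in value


@dataclass
class SystemMessage:
    role: str
    content: str


@dataclass
class UserMessage:
    role: str
    content: str


def reenable_formatting(messages):
    for i, m in enumerate(messages):
        if isinstance(m, SystemMessage):
            messages[i] = SystemMessage(
                role="system", content=CODE_BLOCK_MARKDOWN + m.content
            )
            break  # only the first system message gets the prefix
    return messages


msgs = reenable_formatting(
    [SystemMessage("system", "Be helpful."), UserMessage("user", "hi")]
)
assert msgs[0].content == "Formatting re-enabled. Be helpful."
```

Applying the prefix during translation to LLM format, rather than when the system prompt is built, keeps the stored prompt model-agnostic.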

View File

@@ -1,17 +1,13 @@
from collections.abc import Callable
from collections.abc import Iterator
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.constants import MessageType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -20,54 +16,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
class StreamType(Enum):
SUB_QUESTIONS = "sub_questions"
SUB_ANSWER = "sub_answer"
MAIN_ANSWER = "main_answer"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
stream_type: StreamType = StreamType.MAIN_ANSWER
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class UserKnowledgeFilePacket(BaseModel):
user_files: list[FileDescriptor]
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class OnyxAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
@@ -78,23 +26,11 @@ class StreamingError(BaseModel):
details: dict | None = None # Additional context (tool name, model name, etc.)
class OnyxAnswer(BaseModel):
answer: str | None
class FileChatDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
response: ToolResultType
tool_name: str
class ToolConfig(BaseModel):
id: int
class ProjectSearchConfig(BaseModel):
"""Configuration for search tool availability in project context."""
@@ -102,83 +38,15 @@ class ProjectSearchConfig(BaseModel):
disable_forced_tool: bool
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
datetime_aware: bool = True
include_citations: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
# Note: prompt_ids removed - prompts are now embedded in personas
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| FileChatDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
AnswerStreamPart = (
Packet
| StreamStopInfo
| MessageResponseIDInfo
| StreamingError
| UserKnowledgeFilePacket
| CreateChatSessionID
)
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStream = Iterator[AnswerStreamPart]
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str
answer_citationless: str
top_documents: list[SearchDoc]
error_msg: str | None
message_id: int
citation_info: list[CitationInfo]
class ToolCallResponse(BaseModel):
"""Tool call with full details for non-streaming response."""
@@ -191,8 +59,23 @@ class ToolCallResponse(BaseModel):
pre_reasoning: str | None = None
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str
answer_citationless: str
top_documents: list[SearchDoc]
error_msg: str | None
message_id: int
citation_info: list[CitationInfo]
class ChatFullResponse(BaseModel):
"""Complete non-streaming response with all available data."""
"""Complete non-streaming response with all available data.
NOTE: This model is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
experienced team member. It is very important to 1. avoid bloat and 2. keep this backwards compatible across versions.
"""
# Core response fields
answer: str

View File

@@ -37,7 +37,6 @@ from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ProjectSearchConfig
from onyx.chat.models import StreamingError
@@ -81,8 +80,7 @@ from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import OptionalSearchSetting
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -615,16 +613,27 @@ def handle_stream_message_objects(
user_memory_context = get_memories(user, db_session)
# This is the custom prompt which may come from the Agent or Project. We fetch it early because the inner loop
# (run_llm_loop and run_deep_research_llm_loop) should not need to be aware of the DB form of the Chat History
# processed here; we also need it this early for token reservation.
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
# When use_memories is disabled, don't inject memories into the prompt
# or count them in token reservation, but still pass the full context
# When use_memories is disabled, strip memories from the prompt context
# but keep user info/preferences. The full context is still passed
# to the LLM loop for memory tool persistence.
prompt_memory_context = user_memory_context if user.use_memories else None
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=custom_agent_prompt or "",
persona_system_prompt=max_reserved_system_prompt_tokens_str,
token_counter=token_counter,
files=new_msg_req.file_descriptors,
user_memory_context=prompt_memory_context,
@@ -1016,68 +1025,6 @@ def llm_loop_completion_handle(
)
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User,
db_session: Session,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
bypass_acl: bool = False,
# Additional context that should be included in the chat history, for example:
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
# messages. Both of the below are used for Slack
# NOTE: is not stored in the database, only passed in to the LLM as context
additional_context: str | None = None,
# Slack context for federated Slack search
slack_context: SlackContext | None = None,
) -> AnswerStream:
forced_tool_id = (
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
)
if (
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
):
all_tools = get_tools(db_session)
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
None,
)
forced_tool_id = search_tool_id
translated_new_msg_req = SendMessageRequest(
message=new_msg_req.message,
llm_override=new_msg_req.llm_override,
mock_llm_response=new_msg_req.mock_llm_response,
allowed_tool_ids=new_msg_req.allowed_tool_ids,
forced_tool_id=forced_tool_id,
file_descriptors=new_msg_req.file_descriptors,
internal_search_filters=(
new_msg_req.retrieval_options.filters
if new_msg_req.retrieval_options
else None
),
deep_research=new_msg_req.deep_research,
parent_message_id=new_msg_req.parent_message_id,
chat_session_id=new_msg_req.chat_session_id,
origin=new_msg_req.origin,
include_citations=new_msg_req.include_citations,
)
return handle_stream_message_objects(
new_msg_req=translated_new_msg_req,
user=user,
db_session=db_session,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
bypass_acl=bypass_acl,
additional_context=additional_context,
slack_context=slack_context,
)
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"

View File

@@ -9,13 +9,13 @@ from onyx.db.persona import get_default_behavior_persona
from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
from onyx.prompts.prompt_utils import replace_reminder_tag
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
@@ -25,7 +25,12 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_INFORMATION_HEADER
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
@@ -131,6 +136,59 @@ def build_reminder_message(
return reminder if reminder else None
def _build_user_information_section(
user_memory_context: UserMemoryContext | None,
company_context: str | None,
) -> str:
"""Build the complete '# User Information' section with all sub-sections
in the correct order: Basic Info → Team Info → Preferences → Memories."""
sections: list[str] = []
if user_memory_context:
ctx = user_memory_context
has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role
if has_basic_info:
role_line = (
USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip()
if ctx.user_info.role
else ""
)
if role_line:
role_line = "\n" + role_line
sections.append(
BASIC_INFORMATION_PROMPT.format(
user_name=ctx.user_info.name or "",
user_email=ctx.user_info.email or "",
user_role=role_line,
)
)
if company_context:
sections.append(
TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip())
)
if user_memory_context:
ctx = user_memory_context
if ctx.user_preferences:
sections.append(
USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences)
)
if ctx.memories:
formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories)
sections.append(
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
)
if not sections:
return ""
return USER_INFORMATION_HEADER + "".join(sections)
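
The ordering contract of `_build_user_information_section` (Basic Info, then Team Info, then Preferences, then Memories, with the header emitted only when at least one sub-section exists) can be sketched with stand-in templates; the `HEADER`/`PREFS`/`MEMS` strings below are illustrative, not the real `onyx.prompts.user_info` constants:

```python
HEADER = "# User Information\n"
PREFS = "\n## Preferences\n{user_preferences}\n"
MEMS = "\n## Memories\n{user_memories}\n"


def build_section(preferences, memories):
    sections = []
    if preferences:
        sections.append(PREFS.format(user_preferences=preferences))
    if memories:
        # Memories are rendered as a bulleted list, one per line.
        bullets = "\n".join(f"- {m}" for m in memories)
        sections.append(MEMS.format(user_memories=bullets))
    if not sections:
        return ""  # no header when there is nothing to show
    return HEADER + "".join(sections)


assert build_section(None, []) == ""
out = build_section("Short answers.", ["Works at Acme"])
assert out.startswith("# User Information")
assert out.index("Preferences") < out.index("Memories")
```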
def build_system_prompt(
base_system_prompt: str,
datetime_aware: bool = False,
@@ -138,18 +196,12 @@ def build_system_prompt(
tools: Sequence[Tool] | None = None,
should_cite_documents: bool = False,
include_all_guidance: bool = False,
open_ai_formatting_enabled: bool = False,
) -> str:
"""Should only be called with the default behavior system prompt.
If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function.
"""
system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware)
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
# for OpenAI reasoning models to have correct markdown generation
if open_ai_formatting_enabled:
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
# Replace citation guidance placeholder if present
system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
system_prompt,
@@ -157,16 +209,14 @@ def build_system_prompt(
include_all_guidance=include_all_guidance,
)
# Replace reminder tag placeholder if present
system_prompt = replace_reminder_tag(system_prompt)
company_context = get_company_context()
formatted_user_context = (
user_memory_context.as_formatted_prompt() if user_memory_context else ""
user_info_section = _build_user_information_section(
user_memory_context, company_context
)
if company_context or formatted_user_context:
system_prompt += USER_INFORMATION_HEADER
if company_context:
system_prompt += company_context
if formatted_user_context:
system_prompt += formatted_user_context
system_prompt += user_info_section
# Append citation guidance after company context if placeholder was not present
# This maintains backward compatibility and ensures citations are always enforced when needed

View File

@@ -977,6 +977,7 @@ API_KEY_HASH_ROUNDS = (
# MCP Server Configs
#####
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)
# CORS origins for MCP clients (comma-separated)

View File

@@ -1,4 +1,5 @@
import contextvars
import re
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
@@ -14,6 +15,7 @@ from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -62,11 +64,44 @@ class AirtableClientNotSetUpError(PermissionError):
super().__init__("Airtable Client is not set up, was load_credentials called?")
# Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
# Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ)
_AIRTABLE_URL_PATTERN = re.compile(
r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
)
def parse_airtable_url(
url: str,
) -> tuple[str, str, str | None]:
"""Parse an Airtable URL into (base_id, table_id, view_id).
Accepts URLs like:
https://airtable.com/appXXX/tblYYY
https://airtable.com/appXXX/tblYYY/viwZZZ
https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
Returns:
(base_id, table_id, view_id or None)
Raises:
ValueError if the URL doesn't match the expected format.
"""
match = _AIRTABLE_URL_PATTERN.search(url.strip())
if not match:
raise ValueError(
f"Could not parse Airtable URL: '{url}'. "
"Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
)
return match.group(1), match.group(2), match.group(3)
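
A quick check of what the parser accepts, using the same regex as above (the `parse` wrapper here is a trimmed stand-in for `parse_airtable_url`):

```python
import re

# Same pattern as the connector: base and table IDs are required, the
# view ID segment is optional, and any query string is ignored.
_AIRTABLE_URL_PATTERN = re.compile(
    r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
)


def parse(url):
    m = _AIRTABLE_URL_PATTERN.search(url.strip())
    if not m:
        raise ValueError(f"Could not parse Airtable URL: '{url}'")
    return m.group(1), m.group(2), m.group(3)


assert parse("https://airtable.com/app123/tblABC") == ("app123", "tblABC", None)
assert parse("https://airtable.com/app123/tblABC/viwXYZ?blocks=hide") == (
    "app123",
    "tblABC",
    "viwXYZ",
)
```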
class AirtableConnector(LoadConnector):
def __init__(
self,
base_id: str,
table_name_or_id: str,
base_id: str = "",
table_name_or_id: str = "",
airtable_url: str = "",
treat_all_non_attachment_fields_as_metadata: bool = False,
view_id: str | None = None,
share_id: str | None = None,
@@ -75,16 +110,33 @@ class AirtableConnector(LoadConnector):
"""Initialize an AirtableConnector.
Args:
base_id: The ID of the Airtable base to connect to
table_name_or_id: The name or ID of the table to index
base_id: The ID of the Airtable base (not required when airtable_url is set)
table_name_or_id: The name or ID of the table (not required when airtable_url is set)
airtable_url: An Airtable URL to parse base_id, table_id, and view_id from.
Overrides base_id, table_name_or_id, and view_id if provided.
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
view_id: Optional ID of a specific view to use
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
share_id: Optional ID of a "share" to use for generating record URLs
batch_size: Number of records to process in each batch
Mode is auto-detected: if a specific table is identified (via URL or
base_id + table_name_or_id), the connector indexes that single table.
Otherwise, it discovers and indexes all accessible bases and tables.
"""
# If a URL is provided, parse it to extract base_id, table_id, and view_id
if airtable_url:
parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url(
airtable_url
)
base_id = parsed_base_id
table_name_or_id = parsed_table_id
if parsed_view_id:
view_id = parsed_view_id
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.index_all = not (base_id and table_name_or_id)
self.view_id = view_id
self.share_id = share_id
self.batch_size = batch_size
@@ -103,6 +155,33 @@ class AirtableConnector(LoadConnector):
raise AirtableClientNotSetUpError()
return self._airtable_client
def validate_connector_settings(self) -> None:
if self.index_all:
try:
bases = self.airtable_client.bases()
if not bases:
raise ConnectorValidationError(
"No bases found. Ensure your API token has access to at least one base."
)
except ConnectorValidationError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to list Airtable bases: {e}")
else:
if not self.base_id or not self.table_name_or_id:
raise ConnectorValidationError(
"A valid Airtable URL or base_id and table_name_or_id are required "
"when not using index_all mode."
)
try:
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table.schema()
except Exception as e:
raise ConnectorValidationError(
f"Failed to access table '{self.table_name_or_id}' "
f"in base '{self.base_id}': {e}"
)
@classmethod
def _get_record_url(
cls,
@@ -267,6 +346,7 @@ class AirtableConnector(LoadConnector):
field_name: str,
field_info: Any,
field_type: str,
base_id: str,
table_id: str,
view_id: str | None,
record_id: str,
@@ -291,7 +371,7 @@ class AirtableConnector(LoadConnector):
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
base_id=base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
@@ -326,15 +406,17 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
base_id: str,
base_name: str | None = None,
) -> Document | None:
"""Process a single Airtable record into a Document.
Args:
record: The Airtable record to process
table_schema: Schema information for the table
table_name: Name of the table
table_id: ID of the table
primary_field_name: Name of the primary field, if any
base_id: The ID of the base this record belongs to
base_name: The name of the base (used in semantic ID for index_all mode)
Returns:
Document object representing the record
@@ -367,6 +449,7 @@ class AirtableConnector(LoadConnector):
field_name=field_name,
field_info=field_val,
field_type=field_type,
base_id=base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
@@ -379,11 +462,26 @@ class AirtableConnector(LoadConnector):
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
# Include base name in semantic ID only in index_all mode
if self.index_all and base_name:
semantic_id = (
f"{base_name} > {table_name}: {primary_field_value}"
if primary_field_value
else f"{base_name} > {table_name}"
)
else:
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
# Build hierarchy source_path for Craft file system subdirectory structure.
# This creates: airtable/{base_name}/{table_name}/record.json
source_path: list[str] = []
if base_name:
source_path.append(base_name)
source_path.append(table_name)
return Document(
id=f"airtable__{record_id}",
@@ -391,19 +489,39 @@ class AirtableConnector(LoadConnector):
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
doc_metadata={
"hierarchy": {
"source_path": source_path,
"base_id": base_id,
"table_id": table_id,
"table_name": table_name,
**({"base_name": base_name} if base_name else {}),
}
},
)
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
def _resolve_base_name(self, base_id: str) -> str | None:
"""Try to resolve a human-readable base name from the API."""
try:
for base_info in self.airtable_client.bases():
if base_info.id == base_id:
return base_info.name
except Exception:
logger.debug(f"Could not resolve base name for {base_id}")
return None
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
def _index_table(
self,
base_id: str,
table_name_or_id: str,
base_name: str | None = None,
) -> GenerateDocumentsOutput:
"""Index all records from a single table. Yields batches of Documents."""
# Resolve base name for hierarchy if not provided
if base_name is None:
base_name = self._resolve_base_name(base_id)
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table = self.airtable_client.table(base_id, table_name_or_id)
records = table.all()
table_schema = table.schema()
@@ -415,21 +533,25 @@ class AirtableConnector(LoadConnector):
primary_field_name = field.name
break
logger.info(f"Starting to process Airtable records for {table.name}.")
logger.info(
f"Processing {len(records)} records from table "
f"'{table_schema.name}' in base '{base_name or base_id}'."
)
if not records:
return
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document | HierarchyNode] = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents = []
record_documents: list[Document | HierarchyNode] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record: dict[Future, RecordDict] = {}
future_to_record: dict[Future[Document | None], RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
@@ -440,6 +562,8 @@ class AirtableConnector(LoadConnector):
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
base_id=base_id,
base_name=base_name,
)
] = record
@@ -454,9 +578,58 @@ class AirtableConnector(LoadConnector):
logger.exception(f"Failed to process record {record['id']}")
raise e
yield record_documents
record_documents = []
if record_documents:
yield record_documents
# Yield any remaining records
if record_documents:
yield record_documents
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from one or all tables.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
if self.index_all:
yield from self._load_all()
else:
yield from self._index_table(
base_id=self.base_id,
table_name_or_id=self.table_name_or_id,
)
def _load_all(self) -> GenerateDocumentsOutput:
"""Discover all bases and tables, then index everything."""
bases = self.airtable_client.bases()
logger.info(f"Discovered {len(bases)} Airtable base(s).")
for base_info in bases:
base_id = base_info.id
base_name = base_info.name
logger.info(f"Listing tables for base '{base_name}' ({base_id}).")
try:
base = self.airtable_client.base(base_id)
tables = base.tables()
except Exception:
logger.exception(
f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
)
continue
logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")
for table in tables:
try:
yield from self._index_table(
base_id=base_id,
table_name_or_id=table.id,
base_name=base_name,
)
except Exception:
logger.exception(
f"Failed to index table '{table.name}' ({table.id}) "
f"in base '{base_name}' ({base_id}), skipping."
)
continue
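The discover-everything flow above deliberately skips a failing base or table rather than aborting the whole run; the pattern in isolation (stand-in data structures, not the Airtable client):

```python
import logging

logger = logging.getLogger(__name__)


def index_all(bases, index_table):
    """Yield documents from every table in every base, skipping failures.

    Sketch only: `bases` maps base name -> list of table names, and
    `index_table` is any callable that yields documents (or raises) for
    one table. Failures are logged and the loop continues.
    """
    for base_name, tables in bases.items():
        for table in tables:
            try:
                yield from index_table(base_name, table)
            except Exception:
                logger.exception("Failed to index %s/%s, skipping", base_name, table)
                continue
```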


@@ -79,6 +79,13 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
# The office365 library's ClientContext caches the access token from
# The office365 library's ClientContext caches the access token from its
# first request and never re-invokes the token callback. Microsoft access
# tokens live ~60-75 minutes, so we recreate the cached ClientContext every
# 30 minutes to let MSAL transparently handle token refresh.
_REST_CTX_MAX_AGE_S = 30 * 60
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
@@ -104,30 +111,11 @@ class CertificateData(BaseModel):
thumbprint: str
# TODO(Evan): Remove this once we have a proper token refresh mechanism.
def _clear_cached_token(query_obj: ClientQuery) -> bool:
"""Clear the cached access token on the query object's ClientContext so
the next request re-invokes the token callback and gets a fresh token.
The office365 library's AuthenticationContext.with_access_token() caches
the token in ``_cached_token`` and never refreshes it. Setting it to
``None`` forces re-acquisition on the next request.
Returns True if the token was successfully cleared."""
ctx = getattr(query_obj, "context", query_obj)
auth_ctx = getattr(ctx, "authentication_context", None)
if auth_ctx is not None and hasattr(auth_ctx, "_cached_token"):
auth_ctx._cached_token = None
return True
return False
def sleep_and_retry(
query_obj: ClientQuery, method_name: str, max_retries: int = 3
) -> Any:
"""
Execute a SharePoint query with retry logic for rate limiting
and automatic token refresh on 401 Unauthorized.
Execute a SharePoint query with retry logic for rate limiting.
"""
for attempt in range(max_retries + 1):
try:
@@ -135,15 +123,6 @@ def sleep_and_retry(
except ClientRequestException as e:
status = e.response.status_code if e.response is not None else None
# 401 — token expired. Clear the cached token and retry immediately.
if status == 401 and attempt < max_retries:
cleared = _clear_cached_token(query_obj)
logger.warning(
f"Token expired on {method_name}, attempt {attempt + 1}/{max_retries + 1}, "
f"cleared cached token={cleared}, retrying"
)
continue
# 429 / 503 — rate limit or transient error. Back off and retry.
if status in (429, 503) and attempt < max_retries:
logger.warning(
@@ -742,6 +721,10 @@ class SharepointConnector(
self.include_site_pages = include_site_pages
self.include_site_documents = include_site_documents
self.sp_tenant_domain: str | None = None
self._credential_json: dict[str, Any] | None = None
self._cached_rest_ctx: ClientContext | None = None
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
@@ -767,6 +750,44 @@ class SharepointConnector(
return self._graph_client
def _create_rest_client_context(self, site_url: str) -> ClientContext:
"""Return a ClientContext for SharePoint REST API calls, with caching.
The office365 library's ClientContext caches the access token from its
first request and never re-invokes the token callback. We cache the
context and recreate it when the site URL changes or after
``_REST_CTX_MAX_AGE_S``. On recreation we also call
``load_credentials`` to build a fresh MSAL app with an empty token
cache, guaranteeing a brand-new token from Azure AD."""
elapsed = time.monotonic() - self._cached_rest_ctx_created_at
if (
self._cached_rest_ctx is not None
and self._cached_rest_ctx_url == site_url
and elapsed <= _REST_CTX_MAX_AGE_S
):
return self._cached_rest_ctx
if self._credential_json:
logger.info(
"Rebuilding SharePoint REST client context "
"(elapsed=%.0fs, site_changed=%s)",
elapsed,
self._cached_rest_ctx_url != site_url,
)
self.load_credentials(self._credential_json)
if not self.msal_app or not self.sp_tenant_domain:
raise RuntimeError("MSAL app or tenant domain is not set")
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
self._cached_rest_ctx_url = site_url
self._cached_rest_ctx_created_at = time.monotonic()
return self._cached_rest_ctx
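The cache-or-rebuild logic above (reuse while the key matches and the TTL has not lapsed, otherwise rebuild and restamp) can be illustrated in isolation; a minimal sketch keyed on a monotonic clock, with illustrative names rather than the connector's API:

```python
import time


class TTLCachedFactory:
    """Rebuild an expensive object when the key changes or the TTL lapses."""

    def __init__(self, build, max_age_s: float = 30 * 60):
        self._build = build          # callable: key -> object
        self._max_age_s = max_age_s
        self._value = None
        self._key = None
        self._created_at = 0.0

    def get(self, key):
        elapsed = time.monotonic() - self._created_at
        if (
            self._value is not None
            and self._key == key
            and elapsed <= self._max_age_s
        ):
            return self._value       # still fresh for the same key
        self._value = self._build(key)
        self._key = key
        self._created_at = time.monotonic()
        return self._value
```

`time.monotonic()` is the right clock here because it cannot jump backwards with wall-clock adjustments, so the age comparison stays meaningful.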
@staticmethod
def _strip_share_link_tokens(path: str) -> list[str]:
# Share links often include a token prefix like /:f:/r/ or /:x:/r/.
@@ -1206,21 +1227,6 @@ class SharepointConnector(
# goes over all urls, converts them into SlimDocument objects and then yields them in batches
doc_batch: list[SlimDocument | HierarchyNode] = []
for site_descriptor in site_descriptors:
ctx: ClientContext | None = None
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
if ctx is None:
logger.warning("ClientContext is not set, skipping permissions")
continue
site_url = site_descriptor.url
# Yield site hierarchy node using helper
@@ -1259,6 +1265,7 @@ class SharepointConnector(
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem, drive_name, ctx, self.graph_client
@@ -1278,6 +1285,7 @@ class SharepointConnector(
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page, ctx, self.graph_client
@@ -1289,6 +1297,7 @@ class SharepointConnector(
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._credential_json = credentials
auth_method = credentials.get(
"authentication_method", SharepointAuthMethod.CLIENT_SECRET.value
)
@@ -1705,17 +1714,6 @@ class SharepointConnector(
)
logger.debug(f"Time range: {start_dt} to {end_dt}")
ctx: ClientContext | None = None
if include_permissions:
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
# At this point current_drive_name should be set from popleft()
current_drive_name = checkpoint.current_drive_name
if current_drive_name is None:
@@ -1810,6 +1808,10 @@ class SharepointConnector(
)
try:
ctx: ClientContext | None = None
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)
doc = _convert_driveitem_to_document_with_permissions(
driveitem,
current_drive_name,
@@ -1875,20 +1877,13 @@ class SharepointConnector(
site_pages = self._fetch_site_pages(
site_descriptor, start=start_dt, end=end_dt
)
client_ctx: ClientContext | None = None
if include_permissions:
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
client_ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
client_ctx: ClientContext | None = None
if include_permissions:
client_ctx = self._create_rest_client_context(site_descriptor.url)
yield (
_convert_sitepage_to_document(
site_page,


@@ -6,7 +6,6 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.configs.constants import DocumentSource
from onyx.db.models import SearchSettings
@@ -97,21 +96,6 @@ class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
tenant_id: str | None = None
class ChunkContext(BaseModel):
# If not specified (None), picked up from Persona settings if there is space
# if specified (even if 0), it always uses the specified number of chunks above and below
chunks_above: int | None = None
chunks_below: int | None = None
full_doc: bool = False
@field_validator("chunks_above", "chunks_below")
@classmethod
def check_non_negative(cls, value: int, field: Any) -> int:
if value is not None and value < 0:
raise ValueError(f"{field.name} must be non-negative")
return value
class BasicChunkRequest(BaseModel):
query: str


@@ -19,7 +19,6 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
@@ -672,27 +671,6 @@ def set_as_latest_chat_message(
db_session.commit()
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[DBSearchDoc],
relevance_summary: DocumentRelevance,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
search_doc.document_id
)
if relevance_data is not None:
db_session.execute(
update(DBSearchDoc)
.where(DBSearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
def _sanitize_for_postgres(value: str) -> str:
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
sanitized = value.replace("\x00", "")


@@ -296,4 +296,4 @@ class HierarchyNodeType(str, PyEnum):
class LLMModelFlowType(str, PyEnum):
CHAT = "chat"
VISION = "vision"
EMBEDDINGS = "embeddings"
CONTEXTUAL_RAG = "contextual_rag"


@@ -509,6 +509,12 @@ def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None
return fetch_default_model(db_session, LLMModelFlowType.VISION)
def fetch_default_contextual_rag_model(
db_session: Session,
) -> ModelConfiguration | None:
return fetch_default_model(db_session, LLMModelFlowType.CONTEXTUAL_RAG)
def fetch_default_model(
db_session: Session,
flow_type: LLMModelFlowType,
@@ -646,6 +652,73 @@ def update_default_vision_provider(
)
def update_no_default_contextual_rag_provider(
db_session: Session,
) -> None:
db_session.execute(
update(LLMModelFlow)
.where(
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CONTEXTUAL_RAG,
LLMModelFlow.is_default == True, # noqa: E712
)
.values(is_default=False)
)
db_session.commit()
def update_default_contextual_model(
db_session: Session,
enable_contextual_rag: bool,
contextual_rag_llm_provider: str | None,
contextual_rag_llm_name: str | None,
) -> None:
"""Sets or clears the default contextual RAG model.
Should be called whenever the PRESENT search settings change
(e.g. inline update or FUTURE → PRESENT swap).
"""
if (
not enable_contextual_rag
or not contextual_rag_llm_name
or not contextual_rag_llm_provider
):
update_no_default_contextual_rag_provider(db_session=db_session)
return
provider = fetch_existing_llm_provider(
name=contextual_rag_llm_provider, db_session=db_session
)
if not provider:
raise ValueError(f"Provider '{contextual_rag_llm_provider}' not found")
model_config = next(
(
mc
for mc in provider.model_configurations
if mc.name == contextual_rag_llm_name
),
None,
)
if not model_config:
raise ValueError(
f"Model '{contextual_rag_llm_name}' not found for provider '{contextual_rag_llm_provider}'"
)
add_model_to_flow(
db_session=db_session,
model_configuration_id=model_config.id,
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
)
_update_default_model(
db_session=db_session,
provider_id=provider.id,
model=contextual_rag_llm_name,
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
)
return
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
"""Fetch all LLM providers that are in Auto mode."""
query = (
@@ -760,9 +833,18 @@ def create_new_flow_mapping__no_commit(
)
flow = result.scalar()
if not flow:
# Row already exists — fetch it
flow = db_session.scalar(
select(LLMModelFlow).where(
LLMModelFlow.model_configuration_id == model_configuration_id,
LLMModelFlow.llm_model_flow_type == flow_type,
)
)
if not flow:
raise ValueError(
f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
f"Failed to create or find flow mapping for "
f"model_configuration_id={model_configuration_id} and flow_type={flow_type}"
)
return flow
@@ -900,3 +982,18 @@ def _update_default_model(
model_config.is_visible = True
db_session.commit()
def add_model_to_flow(
db_session: Session,
model_configuration_id: int,
flow_type: LLMModelFlowType,
) -> None:
# Function does nothing on conflict
create_new_flow_mapping__no_commit(
db_session=db_session,
model_configuration_id=model_configuration_id,
flow_type=flow_type,
)
db_session.commit()
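The conflict-tolerant mapping creation above (insert that does nothing on conflict, then fall back to fetching the existing row) can be sketched with plain SQL; sqlite3 stands in for the real Postgres/SQLAlchemy layer, and `INSERT OR IGNORE` plays the role of `ON CONFLICT DO NOTHING`:

```python
import sqlite3


def get_or_create_flow(conn: sqlite3.Connection, model_id: int, flow_type: str) -> int:
    """Insert a (model_id, flow_type) mapping, or fetch it if it already exists."""
    # Insert is a no-op when the UNIQUE constraint would be violated
    conn.execute(
        "INSERT OR IGNORE INTO llm_model_flow (model_id, flow_type) VALUES (?, ?)",
        (model_id, flow_type),
    )
    row = conn.execute(
        "SELECT id FROM llm_model_flow WHERE model_id = ? AND flow_type = ?",
        (model_id, flow_type),
    ).fetchone()
    if row is None:
        raise ValueError(
            f"Failed to create or find flow mapping for model_id={model_id}"
        )
    return row[0]
```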


@@ -7,10 +7,6 @@ from sqlalchemy.orm import Session
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
MAX_MEMORIES_PER_USER = 10
@@ -36,6 +32,15 @@ class UserMemoryContext(BaseModel):
user_preferences: str | None = None
memories: tuple[str, ...] = ()
def without_memories(self) -> "UserMemoryContext":
"""Return a copy with memories cleared but user info/preferences intact."""
return UserMemoryContext(
user_id=self.user_id,
user_info=self.user_info,
user_preferences=self.user_preferences,
memories=(),
)
def as_formatted_list(self) -> list[str]:
"""Returns combined list of user info, preferences, and memories."""
result = []
@@ -50,45 +55,6 @@ class UserMemoryContext(BaseModel):
result.extend(self.memories)
return result
def as_formatted_prompt(self) -> str:
"""Returns structured prompt sections for the system prompt."""
has_basic_info = (
self.user_info.name or self.user_info.email or self.user_info.role
)
if not has_basic_info and not self.user_preferences and not self.memories:
return ""
sections: list[str] = []
if has_basic_info:
role_line = (
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
if self.user_info.role
else ""
)
if role_line:
role_line = "\n" + role_line
sections.append(
BASIC_INFORMATION_PROMPT.format(
user_name=self.user_info.name or "",
user_email=self.user_info.email or "",
user_role=role_line,
)
)
if self.user_preferences:
sections.append(
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
)
if self.memories:
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
sections.append(
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
)
return "".join(sections)
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
user_info = UserInfo(


@@ -4877,3 +4877,90 @@ class BuildMessage(Base):
"ix_build_message_session_turn", "session_id", "turn_index", "created_at"
),
)
"""
SCIM 2.0 Provisioning Models (Enterprise Edition only)
Used for automated user/group provisioning from identity providers (Okta, Azure AD).
"""
class ScimToken(Base):
"""Bearer tokens for IdP SCIM authentication."""
__tablename__ = "scim_token"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, nullable=False)
hashed_token: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False
) # SHA256 = 64 hex chars
token_display: Mapped[str] = mapped_column(
String, nullable=False
) # Last 4 chars for UI identification
created_by_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
is_active: Mapped[bool] = mapped_column(
Boolean, server_default=text("true"), nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
last_used_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_by: Mapped[User] = relationship("User", foreign_keys=[created_by_id])
class ScimUserMapping(Base):
"""Maps SCIM externalId from the IdP to an Onyx User."""
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
user: Mapped[User] = relationship("User", foreign_keys=[user_id])
class ScimGroupMapping(Base):
"""Maps SCIM externalId from the IdP to an Onyx UserGroup."""
__tablename__ = "scim_group_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id", ondelete="CASCADE"), unique=True, nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)


@@ -15,6 +15,8 @@ from onyx.db.index_attempt import (
count_unique_active_cc_pairs_with_successful_index_attempts,
)
from onyx.db.index_attempt import count_unique_cc_pairs_with_successful_index_attempts
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
@@ -80,6 +82,24 @@ def _perform_index_swap(
db_session=db_session,
)
# Update the default contextual model to match the newly promoted settings
try:
update_default_contextual_model(
db_session=db_session,
enable_contextual_rag=new_search_settings.enable_contextual_rag,
contextual_rag_llm_provider=new_search_settings.contextual_rag_llm_provider,
contextual_rag_llm_name=new_search_settings.contextual_rag_llm_name,
)
except ValueError as e:
logger.error(f"Model not found, defaulting to no contextual model: {e}")
update_no_default_contextual_rag_provider(
db_session=db_session,
)
new_search_settings.enable_contextual_rag = False
new_search_settings.contextual_rag_llm_provider = None
new_search_settings.contextual_rag_llm_name = None
db_session.commit()
# This flow is for checking and possibly creating an index so we get all
# indices.
document_indices = get_all_document_indices(new_search_settings, None, None)


@@ -4,23 +4,21 @@ from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from pydantic import BaseModel
from sqlalchemy import Engine
from sqlalchemy import event
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import SessionTransaction
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import AnswerStream
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.models import ChatFullResponse
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import remove_answer_citations
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.chat import create_chat_session
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.users import get_user_by_email
from onyx.evals.models import ChatFullEvalResult
from onyx.evals.models import EvalationAck
from onyx.evals.models import EvalConfigurationOptions
from onyx.evals.models import EvalMessage
@@ -33,18 +31,7 @@ from onyx.evals.provider import get_provider
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -87,193 +74,29 @@ def isolated_ephemeral_session_factory(
conn.close()
class GatherStreamResult(BaseModel):
"""Result of gathering a stream with tool call information."""
answer: str
answer_citationless: str
tools_called: list[str]
tool_call_details: list[dict[str, Any]]
message_id: int
error_msg: str | None = None
citations: list[CitationInfo] = []
timings: EvalTimings | None = None
def gather_stream_with_tools(packets: AnswerStream) -> GatherStreamResult:
"""
Gather streaming packets and extract both answer content and tool call information.
Returns a GatherStreamResult containing the answer and all tools that were called.
"""
stream_start_time = time.time()
answer: str | None = None
citations: list[CitationInfo] = []
error_msg: str | None = None
message_id: int | None = None
tools_called: list[str] = []
tool_call_details: list[dict[str, Any]] = []
# Timing tracking
first_token_time: float | None = None
tool_start_times: dict[str, float] = {} # tool_name -> start time
tool_execution_ms: dict[str, float] = {} # tool_name -> duration in ms
current_tool: str | None = None
def _finalize_tool_timing(tool_name: str) -> None:
"""Record the duration for a tool that just finished."""
if tool_name in tool_start_times:
duration_ms = (time.time() - tool_start_times[tool_name]) * 1000
tool_execution_ms[tool_name] = duration_ms
for packet in packets:
if isinstance(packet, Packet):
obj = packet.obj
# Handle answer content
if isinstance(obj, AgentResponseStart):
# When answer starts, finalize any in-progress tool
if current_tool:
_finalize_tool_timing(current_tool)
current_tool = None
elif isinstance(obj, AgentResponseDelta):
if answer is None:
answer = ""
first_token_time = time.time()
if obj.content:
answer += obj.content
elif isinstance(obj, CitationInfo):
citations.append(obj)
# Track tool calls with timing
elif isinstance(obj, SearchToolStart):
# Finalize any previous tool
if current_tool:
_finalize_tool_timing(current_tool)
tool_name = "WebSearchTool" if obj.is_internet_search else "SearchTool"
current_tool = tool_name
tool_start_times[tool_name] = time.time()
tools_called.append(tool_name)
tool_call_details.append(
{
"tool_name": tool_name,
"tool_type": "search",
"is_internet_search": obj.is_internet_search,
}
)
elif isinstance(obj, ImageGenerationToolStart):
if current_tool:
_finalize_tool_timing(current_tool)
tool_name = "ImageGenerationTool"
current_tool = tool_name
tool_start_times[tool_name] = time.time()
tools_called.append(tool_name)
tool_call_details.append(
{
"tool_name": tool_name,
"tool_type": "image_generation",
}
)
elif isinstance(obj, PythonToolStart):
if current_tool:
_finalize_tool_timing(current_tool)
tool_name = "PythonTool"
current_tool = tool_name
tool_start_times[tool_name] = time.time()
tools_called.append(tool_name)
tool_call_details.append(
{
"tool_name": tool_name,
"tool_type": "python",
"code": obj.code,
}
)
elif isinstance(obj, OpenUrlStart):
if current_tool:
_finalize_tool_timing(current_tool)
tool_name = "OpenURLTool"
current_tool = tool_name
tool_start_times[tool_name] = time.time()
tools_called.append(tool_name)
tool_call_details.append(
{
"tool_name": tool_name,
"tool_type": "open_url",
}
)
elif isinstance(obj, CustomToolStart):
if current_tool:
_finalize_tool_timing(current_tool)
tool_name = obj.tool_name
current_tool = tool_name
tool_start_times[tool_name] = time.time()
tools_called.append(tool_name)
tool_call_details.append(
{
"tool_name": tool_name,
"tool_type": "custom",
}
)
elif isinstance(packet, StreamingError):
logger.warning(f"Streaming error during eval: {packet.error}")
error_msg = packet.error
elif isinstance(packet, MessageResponseIDInfo):
message_id = packet.reserved_assistant_message_id
# Finalize any remaining tool timing
if current_tool:
_finalize_tool_timing(current_tool)
def _chat_full_response_to_eval_result(
full: ChatFullResponse,
stream_start_time: float,
) -> ChatFullEvalResult:
"""Map ChatFullResponse from gather_stream_full to eval result components."""
tools_called = [tc.tool_name for tc in full.tool_calls]
tool_call_details: list[dict[str, Any]] = [
{"tool_name": tc.tool_name, "tool_arguments": tc.tool_arguments}
for tc in full.tool_calls
]
stream_end_time = time.time()
if message_id is None:
# If we got a streaming error, include it in the exception
if error_msg:
raise ValueError(f"Message ID is required. Stream error: {error_msg}")
raise ValueError(
f"Message ID is required. No MessageResponseIDInfo received. "
f"Tools called: {tools_called}"
)
# Allow empty answers for tool-only turns (e.g., in multi-turn evals)
# Some turns may only execute tools without generating a text response
if answer is None:
logger.warning(
"No answer content generated. Tools called: %s. "
"This may be expected for tool-only turns.",
tools_called,
)
answer = ""
# Calculate timings
total_ms = (stream_end_time - stream_start_time) * 1000
first_token_ms = (
(first_token_time - stream_start_time) * 1000 if first_token_time else None
)
stream_processing_ms = (stream_end_time - stream_start_time) * 1000
timings = EvalTimings(
total_ms=total_ms,
llm_first_token_ms=first_token_ms,
tool_execution_ms=tool_execution_ms,
stream_processing_ms=stream_processing_ms,
llm_first_token_ms=None,
tool_execution_ms={},
stream_processing_ms=total_ms,
)
return GatherStreamResult(
answer=answer,
answer_citationless=remove_answer_citations(answer),
return ChatFullEvalResult(
answer=full.answer,
tools_called=tools_called,
tool_call_details=tool_call_details,
message_id=message_id,
error_msg=error_msg,
citations=citations,
citations=full.citation_info,
timings=timings,
)
@@ -413,14 +236,17 @@ def _get_answer_with_tools(
),
)
stream_start_time = time.time()
state_container = ChatStateContainer()
packets = handle_stream_message_objects(
new_msg_req=request,
user=user,
db_session=db_session,
external_state_container=state_container,
)
full = gather_stream_full(packets, state_container)
# Gather stream with tool call tracking
result = gather_stream_with_tools(packets)
result = _chat_full_response_to_eval_result(full, stream_start_time)
# Evaluate tool assertions
assertion_passed, assertion_details = evaluate_tool_assertions(
@@ -551,30 +377,30 @@ def _get_multi_turn_answer_with_tools(
),
)
# Create request for this turn
# Create request for this turn using SendMessageRequest (same API as handle_stream_message_objects)
# Use AUTO_PLACE_AFTER_LATEST_MESSAGE to chain messages
request = CreateChatMessageRequest(
forced_tool_id = forced_tool_ids[0] if forced_tool_ids else None
request = SendMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=AUTO_PLACE_AFTER_LATEST_MESSAGE,
message=msg.message,
file_descriptors=[],
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
llm_override=llm_override,
persona_override_config=full_configuration.persona_override_config,
skip_gen_ai_answer_generation=False,
allowed_tool_ids=full_configuration.allowed_tool_ids,
forced_tool_ids=forced_tool_ids or None,
forced_tool_id=forced_tool_id,
)
# Stream and gather results for this turn
packets = stream_chat_message_objects(
# Stream and gather results for this turn via handle_stream_message_objects + gather_stream_full
stream_start_time = time.time()
state_container = ChatStateContainer()
packets = handle_stream_message_objects(
new_msg_req=request,
user=user,
db_session=db_session,
external_state_container=state_container,
)
full = gather_stream_full(packets, state_container)
result = gather_stream_with_tools(packets)
result = _chat_full_response_to_eval_result(full, stream_start_time)
# Evaluate tool assertions for this turn
assertion_passed, assertion_details = evaluate_tool_assertions(

View File

@@ -7,9 +7,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import PromptOverrideConfig
from onyx.chat.models import ToolConfig
from onyx.db.tools import get_builtin_tool
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -34,6 +31,16 @@ class EvalTimings(BaseModel):
stream_processing_ms: float | None = None # Time to process the stream
class ChatFullEvalResult(BaseModel):
"""Raw eval components from ChatFullResponse (before tool assertions)."""
answer: str
tools_called: list[str]
tool_call_details: list[dict[str, Any]]
citations: list[CitationInfo]
timings: EvalTimings
class EvalToolResult(BaseModel):
"""Result of a single eval with tool call information."""
@@ -72,8 +79,6 @@ class MultiTurnEvalResult(BaseModel):
class EvalConfiguration(BaseModel):
builtin_tool_types: list[str] = Field(default_factory=list)
persona_override_config: PersonaOverrideConfig | None = None
llm: LLMOverride = Field(default_factory=LLMOverride)
search_permissions_email: str
allowed_tool_ids: list[int]
@@ -81,7 +86,6 @@ class EvalConfiguration(BaseModel):
class EvalConfigurationOptions(BaseModel):
builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
persona_override_config: PersonaOverrideConfig | None = None
llm: LLMOverride = LLMOverride(
model_provider=None,
model_version="gpt-4o",
@@ -96,26 +100,7 @@ class EvalConfigurationOptions(BaseModel):
experiment_name: str | None = None
def get_configuration(self, db_session: Session) -> EvalConfiguration:
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
name="Eval",
description="A persona for evaluation",
tools=[
ToolConfig(id=get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id)
for tool in self.builtin_tool_types
],
prompts=[
PromptOverrideConfig(
name="Default",
description="Default prompt for evaluation",
system_prompt="You are a helpful assistant.",
task_prompt="",
datetime_aware=True,
)
],
)
return EvalConfiguration(
persona_override_config=persona_override_config,
llm=self.llm,
search_permissions_email=self.search_permissions_email,
allowed_tool_ids=[

View File

@@ -2,7 +2,6 @@ from collections.abc import Callable
from typing import Any
from onyx.auth.schemas import UserRole
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
@@ -77,7 +76,7 @@ def _build_model_kwargs(
def get_llm_for_persona(
persona: Persona | PersonaOverrideConfig | None,
persona: Persona | None,
user: User,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
@@ -102,20 +101,16 @@ def get_llm_for_persona(
if not provider_model:
raise ValueError("No LLM provider found")
# Only check access control for database Persona entities, not PersonaOverrideConfig
# PersonaOverrideConfig is used for temporary overrides and doesn't have access restrictions
persona_model = persona if isinstance(persona, Persona) else None
# Fetch user group IDs for access control check
user_group_ids = fetch_user_group_ids(db_session, user)
if not can_user_access_llm_provider(
provider_model, user_group_ids, persona_model, user.role == UserRole.ADMIN
provider_model, user_group_ids, persona, user.role == UserRole.ADMIN
):
logger.warning(
"User %s with persona %s cannot access provider %s. Falling back to default provider.",
user.id,
getattr(persona_model, "id", None),
persona.id,
provider_model.name,
)
return get_default_llm(

View File

@@ -92,7 +92,7 @@ class CacheableMessage(BaseModel):
class SystemMessage(CacheableMessage):
role: Literal["system"] = "system"
content: str | list[ContentPart]
content: str
class UserMessage(CacheableMessage):

View File

@@ -1,4 +1,8 @@
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from contextlib import nullcontext
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -49,6 +53,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_env_lock = threading.Lock()
if TYPE_CHECKING:
from litellm import CustomStreamWrapper
from litellm import HTTPHandler
@@ -378,23 +384,30 @@ class LitellmLLM(LLM):
if "api_key" not in passthrough_kwargs:
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
# We only need to set environment variables if custom config is set
env_ctx = (
temporary_env_and_lock(self._custom_config)
if self._custom_config
else nullcontext()
)
with env_ctx:
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
)
return response
except Exception as e:
# for breakpointing
@@ -475,13 +488,21 @@ class LitellmLLM(LLM):
client = HTTPHandler(timeout=timeout_override or self._timeout)
try:
response = cast(
LiteLLMModelResponse,
# When custom_config is set, env vars are temporarily injected
# under a global lock. Using stream=True here means the lock is
# only held during connection setup (not the full inference).
# The chunks are then collected outside the lock and reassembled
# into a single ModelResponse via stream_chunk_builder.
from litellm import stream_chunk_builder
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
stream_response = cast(
LiteLLMCustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
@@ -491,6 +512,11 @@ class LitellmLLM(LLM):
client=client,
),
)
chunks = list(stream_response)
response = cast(
LiteLLMModelResponse,
stream_chunk_builder(chunks),
)
model_response = from_litellm_model_response(response)
@@ -581,3 +607,29 @@ class LitellmLLM(LLM):
finally:
if client is not None:
client.close()
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
"""
Temporarily sets the given environment variables.
The code path is locked while they are set; afterwards the original
environment is restored and the lock is released.
"""
with _env_lock:
logger.debug("Acquired lock in temporary_env_and_lock")
# Store original values (None if key didn't exist)
original_values: dict[str, str | None] = {
key: os.environ.get(key) for key in env_variables
}
try:
os.environ.update(env_variables)
yield
finally:
for key, original_value in original_values.items():
if original_value is None:
os.environ.pop(key, None) # Remove if it didn't exist before
else:
os.environ[key] = original_value # Restore original value
logger.debug("Released lock in temporary_env_and_lock")
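The restore semantics of this context manager are easy to verify in isolation. A standalone sketch of the same pattern (variable names here are illustrative):

```python
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager

_env_lock = threading.Lock()


@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
    # os.environ is process-global state, so all mutation is serialized.
    with _env_lock:
        # Record original values (None if the key did not exist).
        original_values: dict[str, str | None] = {
            key: os.environ.get(key) for key in env_variables
        }
        try:
            os.environ.update(env_variables)
            yield
        finally:
            for key, original_value in original_values.items():
                if original_value is None:
                    os.environ.pop(key, None)  # remove keys we introduced
                else:
                    os.environ[key] = original_value  # restore prior value


os.environ["KEEP_ME"] = "before"
with temporary_env_and_lock({"KEEP_ME": "inside", "NEW_KEY": "temp"}):
    inside = (os.environ["KEEP_ME"], os.environ["NEW_KEY"])
after = (os.environ["KEEP_ME"], "NEW_KEY" in os.environ)
```

Pre-existing keys are restored to their prior values and newly introduced keys are removed, even if the body raises.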

View File

@@ -3,6 +3,7 @@
import uvicorn
from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.utils.logger import setup_logger
@@ -15,13 +16,13 @@ def main() -> None:
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
return
logger.info(f"Starting MCP server on 0.0.0.0:{MCP_SERVER_PORT}")
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app
uvicorn.run(
mcp_app,
host="0.0.0.0",
host=MCP_SERVER_HOST,
port=MCP_SERVER_PORT,
log_config=None,
)

View File

@@ -1,13 +1,11 @@
# ruff: noqa: E501, W605 start
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
from onyx.prompts.constants import REMINDER_TAG_NO_HEADER
DATETIME_REPLACEMENT_PAT = "{{CURRENT_DATETIME}}"
CITATION_GUIDANCE_REPLACEMENT_PAT = "{{CITATION_GUIDANCE}}"
ALT_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
ALT_CITATION_GUIDANCE_REPLACEMENT_PAT = "[[CITATION_GUIDANCE]]"
REMINDER_TAG_REPLACEMENT_PAT = "{{REMINDER_TAG_DESCRIPTION}}"
# Note: this uses a string pattern replacement so the user can also include it in their custom prompts. This keeps the replacement logic simple
@@ -27,7 +25,7 @@ For code you prefer to use Markdown and specify the language.
You can use horizontal rules (---) to separate sections of your responses.
You can use Markdown tables to format your responses for data, lists, and other structured information.
{REMINDER_TAG_DESCRIPTION}
{REMINDER_TAG_REPLACEMENT_PAT}
""".lstrip()

View File

@@ -1,3 +1,4 @@
# ruff: noqa: E501, W605 start
CODE_BLOCK_PAT = "```\n{}\n```"
TRIPLE_BACKTICK = "```"
SYSTEM_REMINDER_TAG_OPEN = "<system-reminder>"
@@ -5,13 +6,12 @@ SYSTEM_REMINDER_TAG_CLOSE = "</system-reminder>"
# Tags format inspired by Anthropic and OpenCode
REMINDER_TAG_NO_HEADER = f"""
User messages may include {SYSTEM_REMINDER_TAG_OPEN} and {SYSTEM_REMINDER_TAG_CLOSE} tags.
These {SYSTEM_REMINDER_TAG_OPEN} tags contain useful information and reminders. \
They are automatically added by the system and are not actual user inputs.
Behave in accordance with these instructions if relevant, and continue normally if they are not.
User messages may include {SYSTEM_REMINDER_TAG_OPEN} and {SYSTEM_REMINDER_TAG_CLOSE} tags. These {SYSTEM_REMINDER_TAG_OPEN} tags contain useful information and reminders. \
They are automatically added by the system and are not actual user inputs. Behave in accordance with these instructions if relevant, and continue normally if they are not.
""".strip()
REMINDER_TAG_DESCRIPTION = f"""
# System Reminders
{REMINDER_TAG_NO_HEADER}
""".strip()
# ruff: noqa: E501, W605 end

View File

@@ -5,14 +5,14 @@ from langchain_core.messages import BaseMessage
from onyx.configs.constants import DocumentSource
from onyx.prompts.chat_prompts import ADDITIONAL_INFO
from onyx.prompts.chat_prompts import ALT_CITATION_GUIDANCE_REPLACEMENT_PAT
from onyx.prompts.chat_prompts import ALT_DATETIME_REPLACEMENT_PAT
from onyx.prompts.chat_prompts import CITATION_GUIDANCE_REPLACEMENT_PAT
from onyx.prompts.chat_prompts import COMPANY_DESCRIPTION_BLOCK
from onyx.prompts.chat_prompts import COMPANY_NAME_BLOCK
from onyx.prompts.chat_prompts import DATETIME_REPLACEMENT_PAT
from onyx.prompts.chat_prompts import REMINDER_TAG_REPLACEMENT_PAT
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -54,11 +54,8 @@ def replace_current_datetime_tag(
include_day_of_week=include_day_of_week,
)
# Check and replace both patterns: {{CURRENT_DATETIME}} and [[CURRENT_DATETIME]]
if DATETIME_REPLACEMENT_PAT in prompt_str:
prompt_str = prompt_str.replace(DATETIME_REPLACEMENT_PAT, datetime_str)
if ALT_DATETIME_REPLACEMENT_PAT in prompt_str:
prompt_str = prompt_str.replace(ALT_DATETIME_REPLACEMENT_PAT, datetime_str)
return prompt_str
@@ -70,7 +67,7 @@ def replace_citation_guidance_tag(
include_all_guidance: bool = False,
) -> tuple[str, bool]:
"""
Replace {{CITATION_GUIDANCE}} or [[CITATION_GUIDANCE]] placeholder with citation guidance if needed.
Replace {{CITATION_GUIDANCE}} placeholder with citation guidance if needed.
Returns:
tuple[str, bool]: (prompt_with_replacement, should_append_fallback)
@@ -78,10 +75,7 @@ def replace_citation_guidance_tag(
- should_append_fallback: True if citation guidance should be appended
(placeholder is not present and citations are needed)
"""
# Check for both patterns: {{CITATION_GUIDANCE}} and [[CITATION_GUIDANCE]]
has_primary_pattern = CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
has_alt_pattern = ALT_CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
placeholder_was_present = has_primary_pattern or has_alt_pattern
placeholder_was_present = CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str
if not placeholder_was_present:
# Placeholder not present - caller should append if citations are needed
@@ -96,30 +90,32 @@ def replace_citation_guidance_tag(
else ""
)
# Replace both patterns if present
if has_primary_pattern:
prompt_str = prompt_str.replace(
CITATION_GUIDANCE_REPLACEMENT_PAT,
citation_guidance,
)
if has_alt_pattern:
prompt_str = prompt_str.replace(
ALT_CITATION_GUIDANCE_REPLACEMENT_PAT,
citation_guidance,
)
prompt_str = prompt_str.replace(
CITATION_GUIDANCE_REPLACEMENT_PAT,
citation_guidance,
)
return prompt_str, False
def replace_reminder_tag(prompt_str: str) -> str:
"""Replace {{REMINDER_TAG_DESCRIPTION}} with the reminder tag content."""
if REMINDER_TAG_REPLACEMENT_PAT in prompt_str:
prompt_str = prompt_str.replace(
REMINDER_TAG_REPLACEMENT_PAT, REMINDER_TAG_DESCRIPTION
)
return prompt_str
def handle_onyx_date_awareness(
prompt_str: str,
# We always replace the pattern {{CURRENT_DATETIME}} or [[CURRENT_DATETIME]] if it shows up
# We always replace the pattern {{CURRENT_DATETIME}} if it shows up
# but if it doesn't show up and the prompt is datetime aware, add it to the prompt at the end.
datetime_aware: bool = False,
) -> str:
"""
If there is a {{CURRENT_DATETIME}} or [[CURRENT_DATETIME]] tag, replace it with the current
date and time no matter what.
If there is a {{CURRENT_DATETIME}} tag, replace it with the current date and time no matter what.
If the prompt is datetime aware, and there are no datetime tags, add it to the prompt.
Do nothing otherwise.
This can later be expanded to support other tags.
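The replace-or-append behavior described in that docstring can be sketched as follows. This is a minimal illustrative version, not the real `handle_onyx_date_awareness` (the appended wording and function signature are assumptions):

```python
DATETIME_REPLACEMENT_PAT = "{{CURRENT_DATETIME}}"


def handle_date_awareness(
    prompt_str: str,
    datetime_str: str,
    datetime_aware: bool = False,
) -> str:
    # Always replace the tag if it is present, regardless of awareness.
    if DATETIME_REPLACEMENT_PAT in prompt_str:
        return prompt_str.replace(DATETIME_REPLACEMENT_PAT, datetime_str)
    # No tag: append only when the prompt opts in to datetime awareness.
    if datetime_aware:
        return f"{prompt_str}\n\nThe current date and time is {datetime_str}."
    # Neither tag nor awareness: leave the prompt untouched.
    return prompt_str
```

The same three-way split (replace if present, append if opted in, otherwise no-op) is what `replace_citation_guidance_tag` signals to its caller via the `should_append_fallback` flag.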

View File

@@ -85,7 +85,7 @@ def send_message(
Enforces rate limiting before executing the agent (via dependency).
Returns a Server-Sent Events (SSE) stream with the agent's response.
Follows the same pattern as /chat/send-message for consistency.
Follows the same pattern as /chat/send-chat-message for consistency.
"""
def stream_generator() -> Generator[str, None, None]:

View File

@@ -4,8 +4,9 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
and uses the native ACP subprocess protocol.
This module includes comprehensive logging for debugging ACP communication.
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
When multiple API server replicas share the same sandbox pod, this client
uses ACP session resumption (session/list + session/resume) to maintain
conversation context across replicas.
Usage:
client = ACPExecClient(
@@ -100,7 +101,7 @@ class ACPClientState:
"""Internal state for the ACP client."""
initialized: bool = False
current_session: ACPSession | None = None
sessions: dict[str, ACPSession] = field(default_factory=dict)
next_request_id: int = 0
agent_capabilities: dict[str, Any] = field(default_factory=dict)
agent_info: dict[str, Any] = field(default_factory=dict)
@@ -144,6 +145,7 @@ class ACPExecClient:
self._reader_thread: threading.Thread | None = None
self._stop_reader = threading.Event()
self._k8s_client: client.CoreV1Api | None = None
self._prompt_count: int = 0 # Track how many prompts sent on this client
def _get_k8s_client(self) -> client.CoreV1Api:
"""Get or create kubernetes client."""
@@ -155,16 +157,16 @@ class ACPExecClient:
self._k8s_client = client.CoreV1Api()
return self._k8s_client
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> str:
"""Start the agent process via exec and initialize a session.
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> None:
"""Start the agent process via exec and initialize the ACP connection.
Only performs the ACP `initialize` handshake. Sessions are created
separately via `create_session()` or `resume_session()`.
Args:
cwd: Working directory for the agent
cwd: Working directory for the `opencode acp` process
timeout: Timeout for initialization
Returns:
The session ID
Raises:
RuntimeError: If startup fails
"""
@@ -176,6 +178,8 @@ class ACPExecClient:
# Start opencode acp via exec
exec_command = ["opencode", "acp", "--cwd", cwd]
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
try:
self._ws_client = k8s_stream(
k8s.connect_get_namespaced_pod_exec,
@@ -201,15 +205,13 @@ class ACPExecClient:
# Give process a moment to start
time.sleep(0.5)
# Initialize ACP connection
# Initialize ACP connection (no session creation)
self._initialize(timeout=timeout)
# Create session
session_id = self._create_session(cwd=cwd, timeout=timeout)
return session_id
logger.info(f"[ACP] Client started: pod={self._pod_name}")
except Exception as e:
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
self.stop()
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
@@ -217,63 +219,153 @@ class ACPExecClient:
"""Background thread to read responses from the exec stream."""
buffer = ""
packet_logger = get_packet_logger()
messages_read = 0
# Track how many consecutive read cycles the buffer has had
# unterminated data (no trailing newline) with no new data arriving.
buffer_stale_cycles = 0
# Track empty read cycles for periodic buffer state logging
empty_read_cycles = 0
while not self._stop_reader.is_set():
if self._ws_client is None:
break
logger.debug(f"[ACP] Reader thread started for pod={self._pod_name}")
try:
if self._ws_client.is_open():
# Read available data
self._ws_client.update(timeout=0.1)
# Read stdout (channel 1)
data = self._ws_client.read_stdout(timeout=0.1)
if data:
buffer += data
# Process complete lines
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if line:
try:
message = json.loads(line)
# Log the raw incoming message
packet_logger.log_jsonrpc_raw_message(
"IN", message, context="k8s"
)
self._response_queue.put(message)
except json.JSONDecodeError:
packet_logger.log_raw(
"JSONRPC-PARSE-ERROR-K8S",
{
"raw_line": line[:500],
"error": "JSON decode failed",
},
)
logger.warning(
f"Invalid JSON from agent: {line[:100]}"
)
else:
packet_logger.log_raw(
"K8S-WEBSOCKET-CLOSED",
{"pod": self._pod_name, "namespace": self._namespace},
)
try:
while not self._stop_reader.is_set():
if self._ws_client is None:
break
except Exception as e:
if not self._stop_reader.is_set():
packet_logger.log_raw(
"K8S-READER-ERROR",
{"error": str(e), "pod": self._pod_name},
try:
if self._ws_client.is_open():
self._ws_client.update(timeout=0.1)
# Read stderr - log any agent errors
stderr_data = self._ws_client.read_stderr(timeout=0.01)
if stderr_data:
logger.warning(
f"[ACP] stderr pod={self._pod_name}: "
f"{stderr_data.strip()[:500]}"
)
# Read stdout
data = self._ws_client.read_stdout(timeout=0.1)
if data:
buffer += data
buffer_stale_cycles = 0
empty_read_cycles = 0
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if line:
try:
message = json.loads(line)
messages_read += 1
packet_logger.log_jsonrpc_raw_message(
"IN", message, context="k8s"
)
self._response_queue.put(message)
except json.JSONDecodeError:
logger.warning(
f"[ACP] Invalid JSON from agent: "
f"{line[:100]}"
)
else:
empty_read_cycles += 1
# No new data arrived this cycle. If the buffer
# has unterminated content, track how long it's
# been sitting there. After a few cycles (~0.5s)
# try to parse it — the agent may have sent the
# last message without a trailing newline.
if buffer.strip():
buffer_stale_cycles += 1
if buffer_stale_cycles == 1:
logger.info(
f"[ACP] Buffer has unterminated data: "
f"{len(buffer)} bytes, "
f"preview={buffer.strip()[:200]}"
)
if buffer_stale_cycles >= 3:
logger.info(
f"[ACP] Attempting stale buffer parse: "
f"{len(buffer)} bytes, "
f"cycles={buffer_stale_cycles}"
)
try:
message = json.loads(buffer.strip())
messages_read += 1
packet_logger.log_jsonrpc_raw_message(
"IN",
message,
context="k8s-unterminated",
)
self._response_queue.put(message)
buffer = ""
buffer_stale_cycles = 0
logger.info(
"[ACP] Stale buffer parsed successfully"
)
except json.JSONDecodeError:
# Not valid JSON yet, keep waiting
logger.debug(
f"[ACP] Stale buffer not valid JSON: "
f"{buffer.strip()[:100]}"
)
# Periodic log: every ~5s (50 cycles at 0.1s each)
# when we're idle with an empty buffer — helps
# confirm the reader is alive and waiting.
if empty_read_cycles % 50 == 0:
logger.info(
f"[ACP] Reader idle: "
f"empty_cycles={empty_read_cycles} "
f"buffer={len(buffer)} bytes "
f"messages_read={messages_read} "
f"pod={self._pod_name}"
)
else:
logger.warning(
f"[ACP] WebSocket closed: pod={self._pod_name}, "
f"messages_read={messages_read}"
)
break
except Exception as e:
if not self._stop_reader.is_set():
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
break
finally:
# Flush any remaining data in buffer
remaining = buffer.strip()
if remaining:
logger.info(
f"[ACP] Flushing buffer on exit: {len(remaining)} bytes, "
f"preview={remaining[:200]}"
)
try:
message = json.loads(remaining)
packet_logger.log_jsonrpc_raw_message(
"IN", message, context="k8s-flush"
)
logger.debug(f"Reader error: {e}")
break
self._response_queue.put(message)
except json.JSONDecodeError:
logger.warning(
f"[ACP] Buffer flush failed (not JSON): " f"{remaining[:200]}"
)
logger.info(
f"[ACP] Reader thread exiting: pod={self._pod_name}, "
f"messages_read={messages_read}, "
f"empty_read_cycles={empty_read_cycles}"
)
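The stale-buffer handling in this reader thread can be distilled into a single pure function over one read cycle — newline-framed JSON is parsed eagerly, and unterminated leftovers are parsed after ~3 idle cycles. A minimal sketch (the function shape is illustrative; the real reader also handles stderr, websocket state, and logging):

```python
import json
from typing import Any


def feed(
    buffer: str, data: str, stale_cycles: int
) -> tuple[str, int, list[dict[str, Any]]]:
    """Process one read cycle of newline-framed JSON-RPC (simplified).

    Returns (new_buffer, new_stale_cycles, parsed_messages).
    """
    messages: list[dict[str, Any]] = []
    if data:
        buffer += data
        stale_cycles = 0  # fresh data resets the staleness counter
        # Drain every complete, newline-terminated message.
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            if line.strip():
                messages.append(json.loads(line))
    elif buffer.strip():
        # No new data, but unterminated content is sitting in the buffer.
        # After ~3 idle cycles, try parsing it — the peer may have sent
        # its final message without a trailing newline.
        stale_cycles += 1
        if stale_cycles >= 3:
            try:
                messages.append(json.loads(buffer.strip()))
                buffer = ""
                stale_cycles = 0
            except json.JSONDecodeError:
                pass  # still incomplete JSON; keep waiting
    return buffer, stale_cycles, messages
```

With the 0.1s read timeout used above, three idle cycles correspond to roughly the ~0.3s staleness window mentioned in the commit message; a partial JSON fragment simply stays buffered until more data arrives.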
def stop(self) -> None:
"""Stop the exec session and clean up."""
session_ids = list(self._state.sessions.keys())
logger.info(
f"[ACP] Stopping client: pod={self._pod_name} "
f"sessions={session_ids} prompts_sent={self._prompt_count}"
)
self._stop_reader.set()
if self._ws_client is not None:
@@ -400,44 +492,215 @@ class ACPExecClient:
if not session_id:
raise RuntimeError("No session ID returned from session/new")
self._state.current_session = ACPSession(session_id=session_id, cwd=cwd)
self._state.sessions[session_id] = ACPSession(session_id=session_id, cwd=cwd)
logger.info(f"[ACP] Created session: acp_session={session_id} cwd={cwd}")
return session_id
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
"""List available ACP sessions, filtered by working directory.
Returns:
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
Empty list if session/list is not supported or fails.
"""
try:
request_id = self._send_request("session/list", {"cwd": cwd})
result = self._wait_for_response(request_id, timeout)
sessions = result.get("sessions", [])
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
return sessions
except Exception as e:
logger.info(f"[ACP] session/list unavailable: {e}")
return []
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
"""Resume an existing ACP session.
Args:
session_id: The ACP session ID to resume
cwd: Working directory for the session
timeout: Timeout for the resume request
Returns:
The session ID
Raises:
RuntimeError: If resume fails
"""
params = {
"sessionId": session_id,
"cwd": cwd,
"mcpServers": [],
}
request_id = self._send_request("session/resume", params)
result = self._wait_for_response(request_id, timeout)
# The response should contain the session ID
resumed_id = result.get("sessionId", session_id)
self._state.sessions[resumed_id] = ACPSession(session_id=resumed_id, cwd=cwd)
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
return resumed_id
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
"""Try to find and resume an existing session for this workspace.
When multiple API server replicas connect to the same sandbox pod,
a previous replica may have already created an ACP session for this
workspace. This method discovers and resumes that session so the
agent retains conversation context.
Args:
cwd: Working directory to search for sessions
timeout: Timeout for ACP requests
Returns:
The resumed session ID, or None if no session could be resumed
"""
# Check if the agent supports session/list + session/resume
session_caps = self._state.agent_capabilities.get("sessionCapabilities", {})
supports_list = session_caps.get("list") is not None
supports_resume = session_caps.get("resume") is not None
if not supports_list or not supports_resume:
logger.debug("[ACP] Agent does not support session resume")
return None
# List sessions for this workspace directory
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
if not sessions:
return None
# Pick the most recent session (first in list, assuming sorted)
target = sessions[0]
target_id = target.get("sessionId")
if not target_id:
logger.warning(
"[ACP] session/list returned a session without sessionId"
)
return None
logger.info(
f"[ACP] Resuming existing session: acp_session={target_id} "
f"(found {len(sessions)})"
)
try:
return self._resume_session(target_id, cwd, timeout)
except Exception as e:
logger.warning(
f"[ACP] session/resume failed for {target_id}: {e}, "
f"falling back to session/new"
)
return None
def create_session(self, cwd: str, timeout: float = 30.0) -> str:
"""Create a new ACP session on this connection.
Args:
cwd: Working directory for the session
timeout: Timeout for the request
Returns:
The ACP session ID
"""
if not self._state.initialized:
raise RuntimeError("Client not initialized. Call start() first.")
return self._create_session(cwd=cwd, timeout=timeout)
def resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
"""Resume an existing ACP session on this connection.
Args:
session_id: The ACP session ID to resume
cwd: Working directory for the session
timeout: Timeout for the request
Returns:
The ACP session ID
"""
if not self._state.initialized:
raise RuntimeError("Client not initialized. Call start() first.")
return self._resume_session(session_id=session_id, cwd=cwd, timeout=timeout)
def get_or_create_session(self, cwd: str, timeout: float = 30.0) -> str:
"""Get an existing session for this cwd, or create/resume one.
Tries in order:
1. Return an already-tracked session for this cwd
2. Resume an existing session from opencode's storage (multi-replica)
3. Create a new session
Args:
cwd: Working directory for the session
timeout: Timeout for ACP requests
Returns:
The ACP session ID
"""
if not self._state.initialized:
raise RuntimeError("Client not initialized. Call start() first.")
# Check if we already have a session for this cwd
for sid, session in self._state.sessions.items():
if session.cwd == cwd:
logger.info(
f"[ACP] Reusing existing session: " f"acp_session={sid} cwd={cwd}"
)
return sid
# Try to resume from opencode's persisted storage
resumed_id = self._try_resume_existing_session(cwd, timeout)
if resumed_id:
return resumed_id
# Create a new session
return self._create_session(cwd=cwd, timeout=timeout)
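The three-tier lookup order in `get_or_create_session` can be sketched with plain dictionaries. This is an illustrative model only — the class name, constructor arguments, and session-ID format are invented for the sketch, not part of `ACPExecClient`:

```python
class SessionPicker:
    """Tiered session lookup (simplified): tracked -> resumable -> new."""

    def __init__(
        self, tracked: dict[str, str], resumable: dict[str, str]
    ) -> None:
        # session_id -> cwd, sessions already live on this connection
        self.tracked = tracked
        # cwd -> session_id, sessions persisted by another replica
        self.resumable = resumable
        self.created = 0

    def get_or_create(self, cwd: str) -> str:
        # 1. Reuse a session already tracked for this cwd.
        for sid, session_cwd in self.tracked.items():
            if session_cwd == cwd:
                return sid
        # 2. Resume a session persisted elsewhere (multi-replica case).
        if cwd in self.resumable:
            sid = self.resumable[cwd]
            self.tracked[sid] = cwd
            return sid
        # 3. Fall back to creating a brand-new session.
        self.created += 1
        sid = f"new-{self.created}"
        self.tracked[sid] = cwd
        return sid
```

Each tier is strictly cheaper than the next: a dict lookup, then one `session/list` + `session/resume` round trip, then `session/new`. Once resumed or created, a session joins the tracked tier and later calls for the same cwd short-circuit at step 1.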
def send_message(
self,
message: str,
session_id: str,
timeout: float = ACP_MESSAGE_TIMEOUT,
) -> Generator[ACPEvent, None, None]:
"""Send a message and stream response events.
"""Send a message to a specific session and stream response events.
Args:
message: The message content to send
session_id: The ACP session ID to send the message to
timeout: Maximum time to wait for complete response (defaults to ACP_MESSAGE_TIMEOUT env var)
Yields:
Typed ACP schema event objects
"""
if self._state.current_session is None:
raise RuntimeError("No active session. Call start() first.")
session_id = self._state.current_session.session_id
if session_id not in self._state.sessions:
raise RuntimeError(
f"Unknown session {session_id}. "
f"Known sessions: {list(self._state.sessions.keys())}"
)
packet_logger = get_packet_logger()
self._prompt_count += 1
prompt_num = self._prompt_count
# Log the start of message processing
packet_logger.log_raw(
"ACP-SEND-MESSAGE-START-K8S",
{
"session_id": session_id,
"pod": self._pod_name,
"namespace": self._namespace,
"message_preview": (
message[:200] + "..." if len(message) > 200 else message
),
"timeout": timeout,
},
logger.info(
f"[ACP] Prompt #{prompt_num} start: "
f"acp_session={session_id} pod={self._pod_name}"
)
# Drain leftover messages from the queue (e.g., session_info_update
# that arrived between prompts).
drained_count = 0
while not self._response_queue.empty():
try:
self._response_queue.get_nowait()
drained_count += 1
except Empty:
break
if drained_count > 0:
logger.debug(f"[ACP] Drained {drained_count} stale messages")
prompt_content = [{"type": "text", "text": message}]
params = {
"sessionId": session_id,
@@ -446,44 +709,109 @@ class ACPExecClient:
request_id = self._send_request("session/prompt", params)
start_time = time.time()
last_event_time = time.time()
events_yielded = 0
messages_processed = 0
keepalive_count = 0
completion_reason = "unknown"
while True:
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
completion_reason = "timeout"
logger.warning(
f"[ACP] Prompt #{prompt_num} timeout: "
f"acp_session={session_id} events={events_yielded}"
)
yield Error(code=-1, message="Timeout waiting for response")
break
try:
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
last_event_time = time.time()
messages_processed += 1
# Log every dequeued message for prompt #2+ to diagnose
# why the response isn't being matched.
if prompt_num >= 2:
msg_id = message_data.get("id")
logger.info(
f"[ACP] Prompt #{prompt_num} dequeued: "
f"id={msg_id} type(id)={type(msg_id).__name__} "
f"method={message_data.get('method')} "
f"keys={list(message_data.keys())} "
f"request_id={request_id}"
)
except Empty:
# Check if reader thread is still alive
if (
self._reader_thread is not None
and not self._reader_thread.is_alive()
):
completion_reason = "reader_thread_dead"
# Drain any final messages the reader flushed before dying
while not self._response_queue.empty():
try:
final_msg = self._response_queue.get_nowait()
if final_msg.get("id") == request_id:
if "error" in final_msg:
error_data = final_msg["error"]
yield Error(
code=error_data.get("code", -1),
message=error_data.get(
"message", "Unknown error"
),
)
else:
result = final_msg.get("result", {})
try:
yield PromptResponse.model_validate(result)
except ValidationError:
pass
break
except Empty:
break
logger.warning(
f"[ACP] Reader thread dead: prompt #{prompt_num} "
f"acp_session={session_id} events={events_yielded}"
)
break
# Send SSE keepalive if idle
idle_time = time.time() - last_event_time
if idle_time >= SSE_KEEPALIVE_INTERVAL:
packet_logger.log_raw(
"SSE-KEEPALIVE-YIELD",
{
"session_id": session_id,
"idle_seconds": idle_time,
},
)
keepalive_count += 1
if keepalive_count % 3 == 0:
reader_alive = (
self._reader_thread is not None
and self._reader_thread.is_alive()
)
elapsed_s = time.time() - start_time
logger.info(
f"[ACP] Prompt #{prompt_num} waiting: "
f"keepalives={keepalive_count} "
f"elapsed={elapsed_s:.0f}s "
f"events={events_yielded} "
f"reader_alive={reader_alive} "
f"queue_size={self._response_queue.qsize()}"
)
yield SSEKeepalive()
last_event_time = time.time()
continue
# Check for JSON-RPC response to our prompt request.
msg_id = message_data.get("id")
is_response = "method" not in message_data and (
msg_id == request_id
or (msg_id is not None and str(msg_id) == str(request_id))
)
if is_response:
completion_reason = "jsonrpc_response"
if "error" in message_data:
error_data = message_data["error"]
completion_reason = "jsonrpc_error"
logger.warning(f"[ACP] Prompt #{prompt_num} error: {error_data}")
packet_logger.log_jsonrpc_response(
request_id, error=error_data, context="k8s"
)
@@ -498,26 +826,16 @@ class ACPExecClient:
)
try:
prompt_response = PromptResponse.model_validate(result)
packet_logger.log_acp_event_yielded(
"prompt_response", prompt_response
)
events_yielded += 1
yield prompt_response
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"type": "prompt_response", "error": str(e)},
)
logger.error(f"[ACP] PromptResponse validation failed: {e}")
# Log completion summary
elapsed_ms = (time.time() - start_time) * 1000
logger.info(
f"[ACP] Prompt #{prompt_num} complete: "
f"reason={completion_reason} acp_session={session_id} "
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
)
break
@@ -526,25 +844,29 @@ class ACPExecClient:
params_data = message_data.get("params", {})
update = params_data.get("update", {})
# Log the notification
packet_logger.log_jsonrpc_notification(
"session/update",
{"update_type": update.get("sessionUpdate")},
context="k8s",
)
prompt_complete = False
for event in self._process_session_update(update):
events_yielded += 1
# Log each yielded event
event_type = self._get_event_type_name(event)
packet_logger.log_acp_event_yielded(event_type, event)
yield event
if isinstance(event, PromptResponse):
prompt_complete = True
break
if prompt_complete:
completion_reason = "prompt_response_via_notification"
elapsed_ms = (time.time() - start_time) * 1000
logger.info(
f"[ACP] Prompt #{prompt_num} complete: "
f"reason={completion_reason} acp_session={session_id} "
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
)
break
# Handle requests from agent - send error response
elif "method" in message_data and "id" in message_data:
logger.debug(
f"[ACP] Unsupported agent request: "
f"method={message_data['method']}"
)
self._send_error_response(
message_data["id"],
@@ -552,113 +874,50 @@ class ACPExecClient:
f"Method not supported: {message_data['method']}",
)
else:
# Elevate to INFO — if the JSON-RPC response is arriving
# but failing the is_response check, this will reveal it.
logger.info(
f"[ACP] Unhandled message: "
f"id={message_data.get('id')} "
f"type(id)={type(message_data.get('id')).__name__} "
f"method={message_data.get('method')} "
f"keys={list(message_data.keys())} "
f"request_id={request_id} "
f"has_result={'result' in message_data} "
f"has_error={'error' in message_data}"
)
def _get_event_type_name(self, event: ACPEvent) -> str:
"""Get the type name for an ACP event."""
if isinstance(event, AgentMessageChunk):
return "agent_message_chunk"
elif isinstance(event, AgentThoughtChunk):
return "agent_thought_chunk"
elif isinstance(event, ToolCallStart):
return "tool_call_start"
elif isinstance(event, ToolCallProgress):
return "tool_call_progress"
elif isinstance(event, AgentPlanUpdate):
return "agent_plan_update"
elif isinstance(event, CurrentModeUpdate):
return "current_mode_update"
elif isinstance(event, PromptResponse):
return "prompt_response"
elif isinstance(event, Error):
return "error"
elif isinstance(event, SSEKeepalive):
return "sse_keepalive"
return "unknown"
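
The `is_response` check tolerates an agent that echoes the request id back with a different JSON type (e.g. the string `"42"` for the integer `42`). A minimal standalone sketch of that comparison, with illustrative message shapes:

```python
def matches_request(message: dict, request_id: int) -> bool:
    # A JSON-RPC response carries no "method" key; requests and
    # notifications from the agent always do.
    if "method" in message:
        return False
    msg_id = message.get("id")
    # Compare both raw and string-coerced, since some agents
    # serialize numeric ids as strings ("42" vs 42).
    return msg_id == request_id or (
        msg_id is not None and str(msg_id) == str(request_id)
    )
```

Without the string coercion, a response whose id round-tripped through a string would never match, and the loop would spin on keepalives until timeout.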
def _process_session_update(
self, update: dict[str, Any]
) -> Generator[ACPEvent, None, None]:
"""Process a session/update notification and yield typed ACP schema objects."""
update_type = update.get("sessionUpdate")
# Map update types to their ACP schema classes
type_map: dict[str, type] = {
"agent_message_chunk": AgentMessageChunk,
"agent_thought_chunk": AgentThoughtChunk,
"tool_call": ToolCallStart,
"tool_call_update": ToolCallProgress,
"plan": AgentPlanUpdate,
"current_mode_update": CurrentModeUpdate,
"prompt_response": PromptResponse,
}
model_class = type_map.get(update_type) # type: ignore[arg-type]
if model_class is not None:
try:
yield model_class.model_validate(update)
except ValidationError as e:
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
elif update_type not in (
"user_message_chunk",
"available_commands_update",
"session_info_update",
"usage_update",
):
logger.debug(f"[ACP] Unknown update type: {update_type}")
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
"""Send an error response to an agent request."""
@@ -673,15 +932,24 @@ class ACPExecClient:
self._ws_client.write_stdin(json.dumps(response) + "\n")
def cancel(self, session_id: str | None = None) -> None:
"""Cancel the current operation on a session.
Args:
session_id: The ACP session ID to cancel. If None, cancels all sessions.
"""
if session_id:
if session_id in self._state.sessions:
self._send_notification(
"session/cancel",
{"sessionId": session_id},
)
else:
for sid in self._state.sessions:
self._send_notification(
"session/cancel",
{"sessionId": sid},
)
def health_check(self, timeout: float = 5.0) -> bool: # noqa: ARG002
"""Check if we can exec into the pod."""
@@ -708,11 +976,9 @@ class ACPExecClient:
return self._ws_client is not None and self._ws_client.is_open()
@property
def session_ids(self) -> list[str]:
"""Get all tracked session IDs."""
return list(self._state.sessions.keys())
def __enter__(self) -> "ACPExecClient":
"""Context manager entry."""


@@ -50,6 +50,7 @@ from pathlib import Path
from uuid import UUID
from uuid import uuid4
from acp.schema import PromptResponse
from kubernetes import client # type: ignore
from kubernetes import config
from kubernetes.client.rest import ApiException # type: ignore
@@ -97,6 +98,10 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# API server pod hostname — used to identify which replica is handling a request.
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
# Constants for pod configuration
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
@@ -348,6 +353,14 @@ class KubernetesSandboxManager(SandboxManager):
self._service_account = SANDBOX_SERVICE_ACCOUNT_NAME
self._file_sync_service_account = SANDBOX_FILE_SYNC_SERVICE_ACCOUNT
# One long-lived ACP client per sandbox (Zed-style architecture).
# Multiple craft sessions share one `opencode acp` process per sandbox.
self._acp_clients: dict[UUID, ACPExecClient] = {}
# Maps (sandbox_id, craft_session_id) → ACP session ID.
# Each craft session has its own ACP session on the shared client.
self._acp_session_ids: dict[tuple[UUID, UUID], str] = {}
# Load AGENTS.md template path
build_dir = Path(__file__).parent.parent.parent # /onyx/server/features/build/
self._agent_instructions_template_path = build_dir / "AGENTS.template.md"
@@ -532,7 +545,7 @@ done
],
resources=client.V1ResourceRequirements(
requests={"cpu": "1000m", "memory": "2Gi"},
limits={"cpu": "4000m", "memory": "8Gi"},
limits={"cpu": "2000m", "memory": "10Gi"},
),
# TODO: Re-enable probes when sandbox container runs actual services.
# Note: Next.js ports are now per-session (dynamic), so container-level
@@ -1156,11 +1169,28 @@ done
def terminate(self, sandbox_id: UUID) -> None:
"""Terminate a sandbox and clean up Kubernetes resources.
Deletes the Service and Pod for the sandbox.
Stops the shared ACP client and removes all session mappings for this
sandbox, then deletes the Service and Pod.
Args:
sandbox_id: The sandbox ID to terminate
"""
# Stop the shared ACP client for this sandbox
acp_client = self._acp_clients.pop(sandbox_id, None)
if acp_client:
try:
acp_client.stop()
except Exception as e:
logger.warning(
f"[SANDBOX-ACP] Failed to stop ACP client for "
f"sandbox {sandbox_id}: {e}"
)
# Remove all session mappings for this sandbox
keys_to_remove = [key for key in self._acp_session_ids if key[0] == sandbox_id]
for key in keys_to_remove:
del self._acp_session_ids[key]
# Clean up Kubernetes resources (needs string for pod/service names)
self._cleanup_kubernetes_resources(str(sandbox_id))
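
Dropping every `(sandbox_id, session_id)` entry for one sandbox, as `terminate` does above, is a key scan over the composite-key dict. Sketched standalone with `UUID` keys:

```python
from uuid import UUID, uuid4


def remove_sandbox_sessions(
    mappings: dict[tuple[UUID, UUID], str], sandbox_id: UUID
) -> int:
    """Drop all session mappings whose composite key starts with sandbox_id."""
    # Materialize the stale keys first: mutating a dict while
    # iterating over it raises RuntimeError.
    stale = [key for key in mappings if key[0] == sandbox_id]
    for key in stale:
        del mappings[key]
    return len(stale)
```

The intermediate list is the important detail; the same pattern appears in both `terminate` and the stale-client replacement path.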
@@ -1395,7 +1425,8 @@ echo "Session workspace setup complete"
) -> None:
"""Clean up a session workspace (on session delete).
Executes kubectl exec to remove the session directory.
Removes the ACP session mapping and executes kubectl exec to remove
the session directory. The shared ACP client persists for other sessions.
Args:
sandbox_id: The sandbox ID
@@ -1403,6 +1434,15 @@ echo "Session workspace setup complete"
nextjs_port: Optional port where Next.js server is running (unused in K8s,
we use PID file instead)
"""
# Remove the ACP session mapping (shared client persists)
session_key = (sandbox_id, session_id)
acp_session_id = self._acp_session_ids.pop(session_key, None)
if acp_session_id:
logger.info(
f"[SANDBOX-ACP] Removed ACP session mapping: "
f"session={session_id} acp_session={acp_session_id}"
)
pod_name = self._get_pod_name(str(sandbox_id))
session_path = f"/workspace/sessions/{session_id}"
@@ -1807,6 +1847,94 @@ echo "Session config regeneration complete"
)
return exec_client.health_check(timeout=timeout)
def _get_or_create_acp_client(self, sandbox_id: UUID) -> ACPExecClient:
"""Get the shared ACP client for a sandbox, creating one if needed.
One long-lived `opencode acp` process per sandbox (Zed-style).
If the existing client's WebSocket has died, replaces it with a new one.
Args:
sandbox_id: The sandbox ID
Returns:
A running ACPExecClient for this sandbox
"""
acp_client = self._acp_clients.get(sandbox_id)
if acp_client is not None and acp_client.is_running:
return acp_client
# Client is dead or doesn't exist — clean up stale one
if acp_client is not None:
logger.warning(
f"[SANDBOX-ACP] Stale ACP client for sandbox {sandbox_id}, "
f"replacing"
)
try:
acp_client.stop()
except Exception:
pass
# Clear session mappings — they're invalid on a new process
keys_to_remove = [
key for key in self._acp_session_ids if key[0] == sandbox_id
]
for key in keys_to_remove:
del self._acp_session_ids[key]
pod_name = self._get_pod_name(str(sandbox_id))
new_client = ACPExecClient(
pod_name=pod_name,
namespace=self._namespace,
container="sandbox",
)
new_client.start(cwd="/workspace")
self._acp_clients[sandbox_id] = new_client
logger.info(
f"[SANDBOX-ACP] Created shared ACP client: "
f"sandbox={sandbox_id} pod={pod_name} "
f"api_pod={_API_SERVER_HOSTNAME}"
)
return new_client
def _get_or_create_acp_session(
self,
sandbox_id: UUID,
session_id: UUID,
acp_client: ACPExecClient,
) -> str:
"""Get the ACP session ID for a craft session, creating one if needed.
Uses the session mapping cache first, then falls back to
`get_or_create_session()` which handles resume from opencode's
persisted storage (multi-replica support).
Args:
sandbox_id: The sandbox ID
session_id: The craft session ID
acp_client: The shared ACP client for this sandbox
Returns:
The ACP session ID
"""
session_key = (sandbox_id, session_id)
acp_session_id = self._acp_session_ids.get(session_key)
if acp_session_id and acp_session_id in acp_client.session_ids:
return acp_session_id
# Session not tracked or was lost — get or create it
session_path = f"/workspace/sessions/{session_id}"
acp_session_id = acp_client.get_or_create_session(cwd=session_path)
self._acp_session_ids[session_key] = acp_session_id
logger.info(
f"[SANDBOX-ACP] Session mapped: "
f"craft_session={session_id} acp_session={acp_session_id}"
)
return acp_session_id
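
The two-level lookup above trusts the cached mapping only when the ACP session is still tracked by the live client; if the `opencode acp` process was replaced, the cached id is stale and must be recreated. A minimal sketch of that cache-validation pattern (string ids and a `create` callback stand in for the real client call):

```python
from typing import Callable


def resolve_session(
    cache: dict[tuple[str, str], str],
    live_sessions: set[str],
    key: tuple[str, str],
    create: Callable[[], str],
) -> str:
    # Trust the cached mapping only if the ACP session is still
    # tracked by the live client; a restarted process invalidates it.
    cached = cache.get(key)
    if cached is not None and cached in live_sessions:
        return cached
    fresh = create()
    cache[key] = fresh
    return fresh
```

The membership check against the live client is what keeps a cached id from pointing at a session on a process that no longer exists.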
def send_message(
self,
sandbox_id: UUID,
@@ -1815,8 +1943,9 @@ echo "Session config regeneration complete"
) -> Generator[ACPEvent, None, None]:
"""Send a message to the CLI agent and stream ACP events.
Runs `opencode acp` via kubectl exec in the sandbox pod.
The agent runs in the session-specific workspace.
Uses a shared ACP client per sandbox (one `opencode acp` process).
Each craft session has its own ACP session ID on that shared process.
Switching between sessions is client-side — just use the right sessionId.
Args:
sandbox_id: The sandbox ID
@@ -1827,37 +1956,46 @@ echo "Session config regeneration complete"
Typed ACP schema event objects
"""
packet_logger = get_packet_logger()
# Get or create the shared ACP client for this sandbox
acp_client = self._get_or_create_acp_client(sandbox_id)
# Get or create the ACP session for this craft session
acp_session_id = self._get_or_create_acp_session(
sandbox_id, session_id, acp_client
)
logger.info(
f"[SANDBOX-ACP] Sending message: "
f"session={session_id} acp_session={acp_session_id} "
f"api_pod={_API_SERVER_HOSTNAME}"
)
# Log the send_message call at sandbox manager level
packet_logger.log_session_start(session_id, sandbox_id, message)
events_count = 0
got_prompt_response = False
try:
for event in acp_client.send_message(message, session_id=acp_session_id):
events_count += 1
if isinstance(event, PromptResponse):
got_prompt_response = True
yield event
# Log successful completion
logger.info(
f"[SANDBOX-ACP] send_message completed: "
f"session={session_id} events={events_count} "
f"got_prompt_response={got_prompt_response}"
)
packet_logger.log_session_end(
session_id, success=True, events_count=events_count
)
except GeneratorExit:
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
# This is the most common failure mode for SSE streaming
logger.warning(
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
f"events={events_count}"
)
packet_logger.log_session_end(
session_id,
success=False,
@@ -1866,7 +2004,10 @@ echo "Session config regeneration complete"
)
raise
except Exception as e:
# Log failure from normal exceptions
logger.error(
f"[SANDBOX-ACP] Exception: session={session_id} "
f"events={events_count} error={e}"
)
packet_logger.log_session_end(
session_id,
success=False,
@@ -1875,19 +2016,16 @@ echo "Session config regeneration complete"
)
raise
except BaseException as e:
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
logger.error(
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} " f"error={e}"
)
packet_logger.log_session_end(
session_id,
success=False,
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
events_count=events_count,
)
raise
def list_directory(
self, sandbox_id: UUID, session_id: UUID, path: str


@@ -1 +1,10 @@
"""Celery tasks for sandbox management."""
from onyx.server.features.build.sandbox.tasks.tasks import (
cleanup_idle_sandboxes_task,
) # noqa: F401
from onyx.server.features.build.sandbox.tasks.tasks import (
sync_sandbox_files,
) # noqa: F401
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]


@@ -11,6 +11,8 @@ from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import delete_search_settings
@@ -118,7 +120,9 @@ def set_new_search_settings(
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(search_settings, new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
@@ -252,6 +256,13 @@ def update_saved_search_settings(
search_settings=search_settings, db_session=db_session
)
logger.info(
f"Updated current search settings to {search_settings.model_dump_json()}"
)
# Re-sync the default contextual model to match the current search settings
_sync_default_contextual_model(db_session)
@router.get("/unstructured-api-key-set")
def unstructured_api_key_set(
@@ -309,3 +320,23 @@ def _validate_contextual_rag_model(
return f"Model {model_name} not found in provider {provider_name}"
return None
def _sync_default_contextual_model(db_session: Session) -> None:
"""Syncs the default CONTEXTUAL_RAG flow to match the PRESENT search settings."""
primary = get_current_search_settings(db_session)
try:
update_default_contextual_model(
db_session=db_session,
enable_contextual_rag=primary.enable_contextual_rag,
contextual_rag_llm_provider=primary.contextual_rag_llm_provider,
contextual_rag_llm_name=primary.contextual_rag_llm_name,
)
except ValueError as e:
logger.error(
f"Error syncing default contextual model, defaulting to no contextual model: {e}"
)
update_no_default_contextual_rag_provider(
db_session=db_session,
)


@@ -30,7 +30,6 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import stream_chat_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -40,8 +39,6 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from onyx.db.chat import add_chats_to_session_from_slack_thread
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import delete_all_chat_sessions_for_user
from onyx.db.chat import delete_chat_session
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
@@ -49,7 +46,6 @@ from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
@@ -71,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
from onyx.server.api_key_usage import check_api_key_usage
@@ -86,10 +81,7 @@ from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
@@ -503,71 +495,8 @@ def delete_chat_session_by_id(
raise HTTPException(status_code=400, detail=str(e))
# WARNING: this endpoint is deprecated and will be removed soon. Use the new send-chat-message endpoint instead.
@router.post("/send-message")
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User = Depends(current_chat_accessible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
_api_key_usage_check: None = Depends(check_api_key_usage),
) -> StreamingResponse:
"""
This endpoint is used for all of the following purposes:
- Sending a new message in the session
- Regenerating a message in the session (just send the same one again)
- Editing a message (similar to regenerating but sending a different message)
- Kicking off a seeded chat session (set `use_existing_user_message`)
Assumes that previous messages have been set as the latest to minimize overhead.
Args:
chat_message_req (CreateChatMessageRequest): Details about the new chat message.
request (Request): The current HTTP request context.
user (User): The current user, obtained via dependency injection.
_ (None): Rate limit check is run if user/group/global rate limits are enabled.
Returns:
StreamingResponse: Streams the response to the new chat message.
"""
tenant_id = get_current_tenant_id()
logger.debug(f"Received new chat message: {chat_message_req.message}")
if not chat_message_req.message and not chat_message_req.use_existing_user_message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=tenant_id if user.is_anonymous else user.email,
event=MilestoneRecordType.RAN_QUERY,
)
def stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in stream_chat_message_objects(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in chat message streaming")
yield json.dumps({"error": str(e)})
finally:
logger.debug("Stream generator finished")
return StreamingResponse(stream_generator(), media_type="text/event-stream")
# NOTE: This endpoint is extremely central to the application, any changes to it should be reviewed and approved by an experienced
# team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
"/send-chat-message",
response_model=ChatFullResponse,
@@ -815,77 +744,6 @@ def get_available_context_tokens_for_session(
"""Endpoints for chat seeding"""
class ChatSeedRequest(BaseModel):
# standard chat session stuff
persona_id: int
# overrides / seeding
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
description: str | None = None
message: str | None = None
# TODO: support this
# initial_message_retrieval_options: RetrievalDetails | None = None
class ChatSeedResponse(BaseModel):
redirect_url: str
@router.post("/seed-chat-session", tags=PUBLIC_API_TAGS)
def seed_chat(
chat_seed_request: ChatSeedRequest,
# NOTE: This endpoint is designed for programmatic access (API keys, external services)
# rather than authenticated user sessions. The user parameter is used for access control
# but the created chat session is "unassigned" (user_id=None) until a user visits the web UI.
# This allows external systems to pre-seed chat sessions that users can then access.
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> ChatSeedResponse:
try:
new_chat_session = create_chat_session(
db_session=db_session,
description=chat_seed_request.description or "",
user_id=None, # this chat session is "unassigned" until a user visits the web UI
persona_id=chat_seed_request.persona_id,
llm_override=chat_seed_request.llm_override,
prompt_override=chat_seed_request.prompt_override,
)
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
if chat_seed_request.message is not None:
root_message = get_or_create_root_message(
chat_session_id=new_chat_session.id, db_session=db_session
)
llm = get_llm_for_persona(
persona=new_chat_session.persona,
user=user,
)
tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
token_count = len(tokenizer.encode(chat_seed_request.message))
create_new_chat_message(
chat_session_id=new_chat_session.id,
parent_message=root_message,
message=chat_seed_request.message,
token_count=token_count,
message_type=MessageType.USER,
db_session=db_session,
)
return ChatSeedResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}&seeded=true"
)
class SeedChatFromSlackRequest(BaseModel):
chat_session_id: UUID


@@ -1,18 +1,15 @@
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import SessionType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import Tag
@@ -20,7 +17,6 @@ from onyx.db.enums import ChatSessionSharedStatus
from onyx.db.models import ChatSession
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.streaming_models import Packet
@@ -40,8 +36,9 @@ class MessageOrigin(str, Enum):
UNSET = "unset"
if TYPE_CHECKING:
pass
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class SourceTag(Tag):
@@ -83,6 +80,8 @@ class ChatFeedbackRequest(BaseModel):
return self
# NOTE: This model is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. keep this backwards compatible across versions.
class SendMessageRequest(BaseModel):
message: str
@@ -141,115 +140,6 @@ class SendMessageRequest(BaseModel):
return self
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"
# Determine whether to run search based on history and latest query
AUTO = "auto"
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history
# If the Persona is configured to not run search (0 chunks), this is bypassed
# If no Prompt is configured, only the search results are shown; this is bypassed
run_search: OptionalSearchSetting = OptionalSearchSetting.AUTO
# Is this a real-time/streaming call or a question where Onyx can take more time?
# Used to determine reranking flow
real_time: bool = True
# The following have defaults in the Persona settings which can be overridden via
# the query, if None, then use Persona settings
filters: BaseFilters | None = None
enable_auto_detect_filters: bool | None = None
# if None, no offset / limit
offset: int | None = None
limit: int | None = None
# If this is set, only the highest matching chunk (or merged chunks) is returned
dedupe_docs: bool = False
class CreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id"""
chat_session_id: UUID
# This is the primary-key (unique identifier) for the previous message of the tree
parent_message_id: int | None
# New message contents
message: str
# Files that we should attach to this message
file_descriptors: list[FileDescriptor] = []
# Prompts are embedded in personas, so no separate prompt_id needed
# If search_doc_ids provided, it should use those docs explicitly
search_doc_ids: list[int] | None
retrieval_options: RetrievalDetails | None
# allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
# enables additional handling to ensure that we regenerate with a given user message ID
regenerate: bool | None = None
# allows the caller to override the Persona / Prompt
# these do not persist in the chat thread details
llm_override: LLMOverride | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
prompt_override: PromptOverride | None = None
# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None
# allow user to specify an alternate assistant
alternate_assistant_id: int | None = None
# This takes priority over the prompt_override
# This won't be a type that's passed in directly from the API
persona_override_config: PersonaOverrideConfig | None = None
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# used for "OpenAI Assistants API"
existing_assistant_message_id: int | None = None
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
skip_gen_ai_answer_generation: bool = False
# List of allowed tool IDs to restrict tool usage. If not provided, all tools available to the persona will be used.
allowed_tool_ids: list[int] | None = None
# List of tool IDs we MUST use.
# TODO: make this a single ID, since it's unclear how to force this for multiple tools at a time.
forced_tool_ids: list[int] | None = None
deep_research: bool = False
# When True (default), enables citation generation with markers and CitationInfo packets
# When False, disables citations: removes markers like [1], [2] and skips CitationInfo packets
include_citations: bool = True
# Origin of the message for telemetry tracking
origin: MessageOrigin = MessageOrigin.UNKNOWN
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
raise ValueError(
"Either search_doc_ids or retrieval_options must be provided, but not both or neither."
)
return self
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
data = super().model_dump(*args, **kwargs)
data["chat_session_id"] = str(data["chat_session_id"])
return data
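The `model_dump` override above stringifies the session ID because `UUID` values are not JSON-serializable. A minimal stdlib sketch of the failure mode and the fix (the payload keys here are illustrative, not the real request schema):

```python
import json
from uuid import uuid4

payload = {"chat_session_id": uuid4(), "message": "hello"}

# json.dumps rejects raw UUID objects outright...
try:
    json.dumps(payload)
    raise AssertionError("expected TypeError")
except TypeError:
    pass

# ...so the override converts the id to a string before serialization.
payload["chat_session_id"] = str(payload["chat_session_id"])
serialized = json.dumps(payload)
```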
class ChatMessageIdentifier(BaseModel):
message_id: int
@@ -365,13 +255,3 @@ class ChatSearchResponse(BaseModel):
groups: list[ChatSessionGroup]
has_more: bool
next_page: int | None = None
class ChatSearchRequest(BaseModel):
query: str | None = None
page: int = 1
page_size: int = 10
class CreateChatResponse(BaseModel):
chat_session_id: str


@@ -343,7 +343,13 @@ def run_tool_calls(
raise ValueError("No user message found in message history")
search_memory_context = (
user_memory_context if inject_memories_in_prompt else None
user_memory_context
if inject_memories_in_prompt
else (
user_memory_context.without_memories()
if user_memory_context
else None
)
)
override_kwargs = SearchToolOverrideKwargs(
starting_citation_num=starting_citation_num,


@@ -17,11 +17,12 @@ disallow_untyped_defs = true
warn_unused_ignores = true
enable_error_code = ["possibly-undefined"]
strict_equality = true
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
exclude = [
"^generated/.*",
"^\\.venv/",
"^onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/",
"^onyx/server/features/build/sandbox/kubernetes/docker/templates/venv/",
"(?:^|/)generated/",
"(?:^|/)\\.venv/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
]
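The switch from `^`-anchored patterns to `(?:^|/)` prefixes is what makes the excludes hold from either working directory, per the comment above. A quick check of that regex behavior (paths are illustrative; mypy applies excludes with search semantics):

```python
import re

patterns = [r"(?:^|/)generated/", r"(?:^|/)\.venv/"]

def excluded(path: str) -> bool:
    # mypy's exclude option uses re.search against the candidate path
    return any(re.search(p, path) for p in patterns)

# Matches whether mypy runs from backend/ (relative path) or the repo root.
assert excluded("generated/client.py")
assert excluded("backend/generated/client.py")
assert excluded("backend/.venv/lib/foo.py")
# No false positive without a directory boundary before "generated".
assert not excluded("onyx/regenerated/file.py")
```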
[[tool.mypy.overrides]]


@@ -23,7 +23,7 @@ def create_new_chat_session(onyx_url: str, api_key: str | None) -> int:
def process_question(onyx_url: str, question: str, api_key: str | None) -> None:
message_endpoint = onyx_url + "/api/chat/send-message"
message_endpoint = onyx_url + "/api/chat/send-chat-message"
chat_session_id = create_new_chat_session(onyx_url, api_key)


@@ -88,7 +88,7 @@ class ChatLoadTester:
token_count = 0
async with session.post(
f"{self.base_url}/chat/send-message",
f"{self.base_url}/chat/send-chat-message",
headers=self.headers,
json={
"chat_session_id": chat_session_id,


@@ -259,6 +259,145 @@ def test_airtable_connector_basic(
compare_documents(doc_batch, expected_docs)
def test_airtable_connector_url(
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
airtable_config: AirtableConfig,
) -> None:
"""Test that passing an Airtable URL produces the same results as base_id + table_id."""
if not airtable_config.table_identifier.startswith("tbl"):
pytest.skip("URL test requires table ID, not table name")
url = f"https://airtable.com/{airtable_config.base_id}/{airtable_config.table_identifier}/{BASE_VIEW_ID}"
connector = AirtableConnector(
airtable_url=url,
treat_all_non_attachment_fields_as_metadata=False,
)
connector.load_credentials({"airtable_access_token": airtable_config.access_token})
doc_batch_generator = connector.load_from_state()
doc_batch = [
doc for doc in next(doc_batch_generator) if not isinstance(doc, HierarchyNode)
]
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 2
expected_docs = [
create_test_document(
id="rec8BnxDLyWeegOuO",
title="Slow Internet",
description="The internet connection is very slow.",
priority="Medium",
status="In Progress",
ticket_id="2",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
all_fields_as_metadata=False,
view_id=BASE_VIEW_ID,
),
create_test_document(
id="reccSlIA4pZEFxPBg",
title="Printer Issue",
description="The office printer is not working.",
priority="High",
status="Open",
ticket_id="1",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
attachments=[
(
"Test.pdf:\ntesting!!!",
f"https://airtable.com/{airtable_config.base_id}/{airtable_config.table_identifier}/{BASE_VIEW_ID}/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
)
],
all_fields_as_metadata=False,
view_id=BASE_VIEW_ID,
),
]
compare_documents(doc_batch, expected_docs)
def test_airtable_connector_index_all(
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
airtable_config: AirtableConfig,
) -> None:
"""Test index_all mode discovers all bases/tables and returns documents.
The test token has access to one base ("Onyx") with three tables:
- Tickets: 3 records, 2 with content (1 empty record is skipped)
- Support Categories: 4 records, all with Category Name field
- Table 3: 3 records, 1 with content (2 empty records are skipped)
Total expected: 7 documents
"""
connector = AirtableConnector()
connector.load_credentials({"airtable_access_token": airtable_config.access_token})
all_docs: list[Document] = []
for batch in connector.load_from_state():
for item in batch:
if isinstance(item, Document):
all_docs.append(item)
# 2 from Tickets + 4 from Support Categories + 1 from Table 3 = 7
assert len(all_docs) == 7
docs_by_id = {d.id: d for d in all_docs}
# Verify all expected document IDs are present
expected_ids = {
# Tickets
"airtable__rec8BnxDLyWeegOuO",
"airtable__reccSlIA4pZEFxPBg",
# Support Categories
"airtable__rec5SgUDcHXcBc8kS",
"airtable__recD3DQHc0BQkDaqX",
"airtable__recPHdnWu1Q9ZxyTg",
"airtable__recWbIElUDz9HjgMd",
# Table 3
"airtable__recNalBz02QU1LhbM",
}
assert docs_by_id.keys() == expected_ids
# In index_all mode, semantic identifiers include "Base Name > Table Name: Primary Field"
assert (
docs_by_id["airtable__rec8BnxDLyWeegOuO"].semantic_identifier
== "Onyx > Tickets: Slow Internet"
)
assert (
docs_by_id["airtable__rec5SgUDcHXcBc8kS"].semantic_identifier
== "Onyx > Support Categories: Software Development"
)
assert (
docs_by_id["airtable__recNalBz02QU1LhbM"].semantic_identifier
== "Onyx > Table 3: A"
)
# Verify hierarchy metadata on a Tickets doc
tickets_doc = docs_by_id["airtable__rec8BnxDLyWeegOuO"]
assert tickets_doc.doc_metadata is not None
hierarchy = tickets_doc.doc_metadata["hierarchy"]
assert hierarchy["source_path"] == ["Onyx", "Tickets"]
assert hierarchy["base_id"] == airtable_config.base_id
assert hierarchy["base_name"] == "Onyx"
assert hierarchy["table_name"] == "Tickets"
# Verify hierarchy on a Support Categories doc
cat_doc = docs_by_id["airtable__rec5SgUDcHXcBc8kS"]
assert cat_doc.doc_metadata is not None
assert cat_doc.doc_metadata["hierarchy"]["source_path"] == [
"Onyx",
"Support Categories",
]
def test_airtable_connector_all_metadata(
mock_get_unstructured_api_key: MagicMock, # noqa: ARG001
airtable_config: AirtableConfig,


@@ -4,8 +4,8 @@ from typing import cast
from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import MessageResponseIDInfo
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments


@@ -6,9 +6,8 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.chat.process_message import handle_stream_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_providers
@@ -18,8 +17,8 @@ from onyx.db.llm import upsert_llm_provider
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import Packet
@@ -70,17 +69,13 @@ def test_answer_with_only_anthropic_provider(
persona_id=0,
)
chat_request = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=None,
chat_request = SendMessageRequest(
message="hello",
file_descriptors=[],
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
chat_session_id=chat_session.id,
)
response_stream: list[AnswerStreamPart] = []
for packet in stream_chat_message_objects(
for packet in handle_stream_message_objects(
new_msg_req=chat_request,
user=test_user,
db_session=db_session,


@@ -4,14 +4,13 @@ from datetime import datetime
from sqlalchemy.orm import Session
from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.chat.process_message import handle_stream_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from tests.external_dependency_unit.answer.conftest import ensure_default_llm_provider
from tests.external_dependency_unit.conftest import create_test_user
@@ -42,18 +41,12 @@ def test_stream_chat_current_date_response(
persona_id=default_persona.id,
)
chat_request = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=None,
chat_request = SendMessageRequest(
message="Please respond only with the current date in the format 'Weekday Month DD, YYYY'.",
file_descriptors=[],
prompt_override=None,
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
query_override=None,
chat_session_id=chat_session.id,
)
gen = stream_chat_message_objects(
gen = handle_stream_message_objects(
new_msg_req=chat_request,
user=test_user,
db_session=db_session,


@@ -7,8 +7,8 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import MessageResponseIDInfo
from onyx.configs.constants import DocumentSource
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal


@@ -6,15 +6,14 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.chat.process_message import handle_stream_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.models import RecencyBiasSetting
from onyx.db.models import User
from onyx.db.persona import upsert_persona
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import Packet
from tests.external_dependency_unit.answer.conftest import ensure_default_llm_provider
@@ -100,18 +99,12 @@ def test_stream_chat_message_objects_without_web_search(
persona_id=test_persona.id,
)
# Create the chat message request with a query that attempts to force web search
chat_request = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=None,
chat_request = SendMessageRequest(
message="run a web search for 'Onyx'",
file_descriptors=[],
prompt_override=None,
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
query_override=None,
chat_session_id=chat_session.id,
)
# Call stream_chat_message_objects
response_generator = stream_chat_message_objects(
# Call handle_stream_message_objects
response_generator = handle_stream_message_objects(
new_msg_req=chat_request,
user=test_user,
db_session=db_session,


@@ -5,6 +5,7 @@ These tests verify that:
1. USER_REMINDER messages are wrapped with <system-reminder> tags
2. The wrapped messages are converted to UserMessage type for the LLM
3. The tags are properly applied around the message content
4. CODE_BLOCK_MARKDOWN is prepended to system messages for models that need it
"""
import pytest
@@ -14,7 +15,9 @@ from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.llm.interfaces import LLMConfig
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
@@ -175,3 +178,161 @@ class TestUserReminderMessageType:
assert SYSTEM_REMINDER_TAG_OPEN not in msg.content
assert SYSTEM_REMINDER_TAG_CLOSE not in msg.content
assert msg.content == "This is a normal user message."
def _create_llm_config(model_name: str) -> LLMConfig:
"""Create a LLMConfig with the specified model name."""
return LLMConfig(
model_provider="openai",
model_name=model_name,
temperature=0.7,
api_key="test-key",
api_base=None,
api_version=None,
max_input_tokens=128000,
)
class TestCodeBlockMarkdownFormatting:
"""Tests for CODE_BLOCK_MARKDOWN prefix handling in translate_history_to_llm_format.
OpenAI reasoning models (o1, o3, gpt-5) need a "Formatting re-enabled. " prefix
in their system messages for correct markdown generation.
"""
def test_o1_model_prepends_markdown_to_string(self) -> None:
"""Test that o1 model prepends CODE_BLOCK_MARKDOWN to string system message."""
llm_config = _create_llm_config("o1")
history = [
ChatMessageSimple(
message="You are a helpful assistant.",
token_count=10,
message_type=MessageType.SYSTEM,
)
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 1
msg = result[0]
assert isinstance(msg, SystemMessage)
assert isinstance(msg.content, str)
assert msg.content == CODE_BLOCK_MARKDOWN + "You are a helpful assistant."
def test_o3_model_prepends_markdown(self) -> None:
"""Test that o3 model prepends CODE_BLOCK_MARKDOWN to system message."""
llm_config = _create_llm_config("o3-mini")
history = [
ChatMessageSimple(
message="System prompt here.",
token_count=10,
message_type=MessageType.SYSTEM,
)
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 1
msg = result[0]
assert isinstance(msg, SystemMessage)
assert isinstance(msg.content, str)
assert msg.content.startswith(CODE_BLOCK_MARKDOWN)
def test_gpt5_model_prepends_markdown(self) -> None:
"""Test that gpt-5 model prepends CODE_BLOCK_MARKDOWN to system message."""
llm_config = _create_llm_config("gpt-5")
history = [
ChatMessageSimple(
message="System prompt here.",
token_count=10,
message_type=MessageType.SYSTEM,
)
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 1
msg = result[0]
assert isinstance(msg, SystemMessage)
assert isinstance(msg.content, str)
assert msg.content.startswith(CODE_BLOCK_MARKDOWN)
def test_gpt4o_does_not_prepend(self) -> None:
"""Test that gpt-4o model does NOT prepend CODE_BLOCK_MARKDOWN."""
llm_config = _create_llm_config("gpt-4o")
history = [
ChatMessageSimple(
message="You are a helpful assistant.",
token_count=10,
message_type=MessageType.SYSTEM,
)
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 1
msg = result[0]
assert isinstance(msg, SystemMessage)
assert isinstance(msg.content, str)
# Should NOT have the prefix
assert msg.content == "You are a helpful assistant."
assert not msg.content.startswith(CODE_BLOCK_MARKDOWN)
def test_no_system_message_no_crash(self) -> None:
"""Test that history without system message doesn't crash."""
llm_config = _create_llm_config("o1")
history = [
ChatMessageSimple(
message="Hello!",
token_count=5,
message_type=MessageType.USER,
)
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 1
msg = result[0]
assert isinstance(msg, UserMessage)
assert msg.content == "Hello!"
def test_only_first_system_message_modified(self) -> None:
"""Test that only the first system message gets the prefix."""
llm_config = _create_llm_config("o1")
history = [
ChatMessageSimple(
message="First system prompt.",
token_count=10,
message_type=MessageType.SYSTEM,
),
ChatMessageSimple(
message="Hello!",
token_count=5,
message_type=MessageType.USER,
),
ChatMessageSimple(
message="Second system prompt.",
token_count=10,
message_type=MessageType.SYSTEM,
),
]
raw_result = translate_history_to_llm_format(history, llm_config)
result = _ensure_list(raw_result)
assert len(result) == 3
# First system message should have prefix
first_sys = result[0]
assert isinstance(first_sys, SystemMessage)
assert isinstance(first_sys.content, str)
assert first_sys.content.startswith(CODE_BLOCK_MARKDOWN)
# Second system message should NOT have prefix (only first one is modified)
second_sys = result[2]
assert isinstance(second_sys, SystemMessage)
assert isinstance(second_sys.content, str)
assert not second_sys.content.startswith(CODE_BLOCK_MARKDOWN)
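The gating these tests exercise can be sketched as follows. The helper name and the prefix-matching rule are assumptions for illustration (the real logic lives in `translate_history_to_llm_format`); the prefix value is taken from the docstring above, and the model list from the test cases:

```python
# Hypothetical sketch, not the actual onyx helper.
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "
# Assumed reasoning-model name prefixes, inferred from the tests above.
REASONING_MODEL_PREFIXES = ("o1", "o3", "gpt-5")

def prepend_markdown_prefix(model_name: str, system_prompt: str) -> str:
    """Prepend the formatting hint for reasoning models, at most once."""
    if model_name.startswith(REASONING_MODEL_PREFIXES) and not system_prompt.startswith(
        CODE_BLOCK_MARKDOWN
    ):
        return CODE_BLOCK_MARKDOWN + system_prompt
    return system_prompt

assert prepend_markdown_prefix("o1", "Be helpful.").startswith(CODE_BLOCK_MARKDOWN)
assert prepend_markdown_prefix("gpt-4o", "Be helpful.") == "Be helpful."
```

Note the idempotence guard: only the first system message gets the prefix, so re-applying the helper must not double it.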


@@ -8,7 +8,6 @@ import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy.orm import Session
from onyx.chat.models import MessageResponseIDInfo
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -21,6 +20,7 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.query_and_chat.chat_backend import create_new_chat_session
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from tests.external_dependency_unit.answer.stream_test_assertions import (
assert_answer_stream_part_correct,
)


@@ -10,11 +10,12 @@ from sqlalchemy.orm import Session
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_default_contextual_rag_model
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.indexing.indexing_pipeline import IndexingPipelineResult
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
@@ -131,26 +132,37 @@ def baseline_search_settings(
) -> None:
"""Ensure a baseline PRESENT search settings row exists in the DB,
which is required before set_new_search_settings can be called."""
baseline = _make_saved_search_settings(enable_contextual_rag=False)
create_search_settings(
search_settings=_make_saved_search_settings(enable_contextual_rag=False),
search_settings=baseline,
db_session=db_session,
status=IndexModelStatus.PRESENT,
)
# Sync default contextual model to match PRESENT (clears any leftover state)
update_default_contextual_model(
db_session=db_session,
enable_contextual_rag=baseline.enable_contextual_rag,
contextual_rag_llm_provider=baseline.contextual_rag_llm_provider,
contextual_rag_llm_name=baseline.contextual_rag_llm_name,
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:
"""After creating search settings via set_new_search_settings with
contextual RAG enabled, run_indexing_pipeline should call
get_llm_for_contextual_rag with the LLM names from those settings."""
"""After creating FUTURE settings and swapping to PRESENT,
fetch_default_contextual_rag_model should match the PRESENT settings
and run_indexing_pipeline should call get_llm_for_contextual_rag."""
_create_llm_provider_and_model(
db_session=db_session,
provider_name=TEST_CONTEXTUAL_RAG_LLM_PROVIDER,
@@ -163,6 +175,20 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
db_session=db_session,
)
# PRESENT still has contextual RAG disabled, so default should be None
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is None
# Swap FUTURE → PRESENT (with 0 cc-pairs, REINDEX swaps immediately)
mock_get_all_doc_indices.return_value = []
old_settings = check_and_perform_index_swap(db_session)
assert old_settings is not None, "Swap should have occurred"
# Now PRESENT has contextual RAG enabled, default should match
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is not None
assert default_model.name == TEST_CONTEXTUAL_RAG_LLM_NAME
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
mock_get_llm.assert_called_once_with(
@@ -172,16 +198,21 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
def test_indexing_pipeline_uses_updated_contextual_rag_settings(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
tenant_context: None, # noqa: ARG001
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:
"""After updating search settings via update_saved_search_settings,
run_indexing_pipeline should use the updated LLM names."""
"""After creating FUTURE settings, swapping to PRESENT, then updating
via update_saved_search_settings, run_indexing_pipeline should use
the updated LLM names."""
_create_llm_provider_and_model(
db_session=db_session,
provider_name=TEST_CONTEXTUAL_RAG_LLM_PROVIDER,
@@ -193,20 +224,28 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
model_name=UPDATED_CONTEXTUAL_RAG_LLM_NAME,
)
# Create baseline PRESENT settings with contextual RAG already enabled
create_search_settings(
search_settings=_make_saved_search_settings(),
# Create FUTURE settings with contextual RAG enabled
set_new_search_settings(
search_settings_new=_make_creation_request(),
_=MagicMock(),
db_session=db_session,
status=IndexModelStatus.PRESENT,
)
# Retire any FUTURE settings left over from other tests so the
# pipeline uses the PRESENT (primary) settings we just created.
secondary = get_secondary_search_settings(db_session)
if secondary:
update_search_settings_status(secondary, IndexModelStatus.PAST, db_session)
# PRESENT still has contextual RAG disabled, so default should be None
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is None
# Update LLM names via the endpoint function
# Swap FUTURE → PRESENT (with 0 cc-pairs, REINDEX swaps immediately)
mock_get_all_doc_indices.return_value = []
old_settings = check_and_perform_index_swap(db_session)
assert old_settings is not None, "Swap should have occurred"
# Now PRESENT has contextual RAG enabled, default should match
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is not None
assert default_model.name == TEST_CONTEXTUAL_RAG_LLM_NAME
# Update the PRESENT LLM names
update_saved_search_settings(
search_settings=_make_saved_search_settings(
llm_name=UPDATED_CONTEXTUAL_RAG_LLM_NAME,
@@ -216,6 +255,10 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
db_session=db_session,
)
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is not None
assert default_model.name == UPDATED_CONTEXTUAL_RAG_LLM_NAME
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
mock_get_llm.assert_called_once_with(
@@ -231,6 +274,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:
@@ -248,6 +292,10 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
db_session=db_session,
)
# PRESENT has contextual RAG disabled, so default should be None
default_model = fetch_default_contextual_rag_model(db_session)
assert default_model is None
_run_indexing_pipeline_with_mocks(mock_get_llm, mock_index_handler, db_session)
mock_get_llm.assert_not_called()


@@ -29,7 +29,7 @@ def test_create_chat_session_and_send_messages() -> None:
# Send first message
first_message = "Hello, this is a test message."
send_message_response = requests.post(
f"{base_url}/chat/send-message",
f"{base_url}/chat/send-chat-message",
json={
"chat_session_id": chat_session_id,
"message": first_message,
@@ -43,7 +43,7 @@ def test_create_chat_session_and_send_messages() -> None:
# Send second message
second_message = "Can you provide more information?"
send_message_response = requests.post(
f"{base_url}/chat/send-message",
f"{base_url}/chat/send-chat-message",
json={
"chat_session_id": chat_session_id,
"message": second_message,


@@ -12,10 +12,9 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -104,37 +103,27 @@ class ChatSessionManager:
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
search_doc_ids: list[int] | None = None,
retrieval_options: RetrievalDetails | None = None,
query_override: str | None = None,
regenerate: bool | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
chat_session: DATestChatSession | None = None,
mock_llm_response: str | None = None,
deep_research: bool = False,
llm_override: LLMOverride | None = None,
) -> StreamedResponse:
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
chat_message_req = SendMessageRequest(
message=message,
chat_session_id=chat_session_id,
parent_message_id=(
parent_message_id
if parent_message_id is not None
else AUTO_PLACE_AFTER_LATEST_MESSAGE
),
file_descriptors=file_descriptors or [],
search_doc_ids=search_doc_ids or [],
retrieval_options=retrieval_options,
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
mock_llm_response=mock_llm_response,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
allowed_tool_ids=allowed_tool_ids,
forced_tool_ids=forced_tool_ids,
forced_tool_id=forced_tool_ids[0] if forced_tool_ids else None,
mock_llm_response=mock_llm_response,
deep_research=deep_research,
llm_override=llm_override,
)
headers = (
@@ -145,8 +134,8 @@ class ChatSessionManager:
cookies = user_performing_action.cookies if user_performing_action else None
response = requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=headers,
stream=True,
cookies=cookies,
@@ -182,17 +171,11 @@ class ChatSessionManager:
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
search_doc_ids: list[int] | None = None,
query_override: str | None = None,
regenerate: bool | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
mock_llm_response: str | None = None,
deep_research: bool = False,
llm_override: LLMOverride | None = None,
) -> None:
"""
Send a message and simulate client disconnect before stream completes.
@@ -204,33 +187,25 @@ class ChatSessionManager:
chat_session_id: The chat session ID
message: The message to send
disconnect_after_packets: Disconnect after receiving this many packets.
If None, disconnect_after_type must be specified.
disconnect_after_type: Disconnect after receiving a packet of this type
(e.g., "message_start", "search_tool_start"). If None,
disconnect_after_packets must be specified.
... (other standard message parameters)
Returns:
StreamedResponse containing data received before disconnect,
with is_disconnected=True flag set.
None. Caller can verify server-side cleanup via get_chat_history etc.
"""
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
chat_message_req = SendMessageRequest(
message=message,
chat_session_id=chat_session_id,
parent_message_id=(
parent_message_id
if parent_message_id is not None
else AUTO_PLACE_AFTER_LATEST_MESSAGE
),
file_descriptors=file_descriptors or [],
search_doc_ids=search_doc_ids or [],
retrieval_options=RetrievalDetails(), # This will be deprecated soon anyway
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
mock_llm_response=mock_llm_response,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
allowed_tool_ids=allowed_tool_ids,
forced_tool_ids=forced_tool_ids,
forced_tool_id=forced_tool_ids[0] if forced_tool_ids else None,
mock_llm_response=mock_llm_response,
deep_research=deep_research,
llm_override=llm_override,
)
headers = (
@@ -243,8 +218,8 @@ class ChatSessionManager:
packets_received = 0
with requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=headers,
stream=True,
cookies=cookies,

View File

@@ -1,7 +1,5 @@
from onyx.configs import app_configs
from onyx.configs.constants import DocumentSource
from onyx.server.query_and_chat.models import OptionalSearchSetting
from onyx.server.query_and_chat.models import RetrievalDetails
from onyx.tools.constants import SEARCH_TOOL_ID
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
@@ -172,7 +170,7 @@ def test_run_search_always_maps_to_forced_search_tool(admin_user: DATestUser) ->
chat_session_id=chat_session.id,
message="always run search",
user_performing_action=admin_user,
retrieval_options=RetrievalDetails(run_search=OptionalSearchSetting.ALWAYS),
forced_tool_ids=[search_tool_id],
mock_llm_response='{"name":"internal_search","arguments":{"queries":["gamma"]}}',
)

View File

@@ -17,6 +17,7 @@ class TestOnyxWebCrawler:
content from public websites correctly.
"""
@pytest.mark.skip(reason="Temporarily disabled")
def test_fetches_public_url_successfully(self, admin_user: DATestUser) -> None:
"""Test that the crawler can fetch content from a public URL."""
response = requests.post(
@@ -40,6 +41,7 @@ class TestOnyxWebCrawler:
assert "This domain is for use in" in content
assert "documentation" in content or "illustrative" in content
@pytest.mark.skip(reason="Temporarily disabled")
def test_fetches_multiple_urls(self, admin_user: DATestUser) -> None:
"""Test that the crawler can fetch multiple URLs in one request."""
response = requests.post(
@@ -263,6 +265,7 @@ def _activate_exa_provider(admin_user: DATestUser) -> int:
@pytestmark_exa
@pytest.mark.skip(reason="Temporarily disabled")
def test_web_search_endpoints_with_exa(
reset: None, # noqa: ARG001
admin_user: DATestUser,

View File

@@ -0,0 +1,152 @@
"""Tests for chat_utils.py, specifically get_custom_agent_prompt."""
from unittest.mock import MagicMock
from onyx.chat.chat_utils import get_custom_agent_prompt
from onyx.configs.constants import DEFAULT_PERSONA_ID
class TestGetCustomAgentPrompt:
"""Tests for the get_custom_agent_prompt function."""
def _create_mock_persona(
self,
persona_id: int = 1,
system_prompt: str | None = None,
replace_base_system_prompt: bool = False,
) -> MagicMock:
"""Create a mock Persona with the specified attributes."""
persona = MagicMock()
persona.id = persona_id
persona.system_prompt = system_prompt
persona.replace_base_system_prompt = replace_base_system_prompt
return persona
def _create_mock_chat_session(
self,
project: MagicMock | None = None,
) -> MagicMock:
"""Create a mock ChatSession with the specified attributes."""
chat_session = MagicMock()
chat_session.project = project
return chat_session
def _create_mock_project(
self,
instructions: str = "",
) -> MagicMock:
"""Create a mock UserProject with the specified attributes."""
project = MagicMock()
project.instructions = instructions
return project
def test_default_persona_no_project(self) -> None:
"""Test that default persona without a project returns None."""
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
chat_session = self._create_mock_chat_session(project=None)
result = get_custom_agent_prompt(persona, chat_session)
assert result is None
def test_default_persona_with_project_instructions(self) -> None:
"""Test that default persona in a project returns project instructions."""
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
project = self._create_mock_project(instructions="Do X and Y")
chat_session = self._create_mock_chat_session(project=project)
result = get_custom_agent_prompt(persona, chat_session)
assert result == "Do X and Y"
def test_default_persona_with_empty_project_instructions(self) -> None:
"""Test that default persona in a project with empty instructions returns None."""
persona = self._create_mock_persona(persona_id=DEFAULT_PERSONA_ID)
project = self._create_mock_project(instructions="")
chat_session = self._create_mock_chat_session(project=project)
result = get_custom_agent_prompt(persona, chat_session)
assert result is None
def test_custom_persona_replace_base_prompt_true(self) -> None:
"""Test that custom persona with replace_base_system_prompt=True returns None."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt="Custom system prompt",
replace_base_system_prompt=True,
)
chat_session = self._create_mock_chat_session(project=None)
result = get_custom_agent_prompt(persona, chat_session)
assert result is None
def test_custom_persona_with_system_prompt(self) -> None:
"""Test that custom persona with system_prompt returns the system_prompt."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt="Custom system prompt",
replace_base_system_prompt=False,
)
chat_session = self._create_mock_chat_session(project=None)
result = get_custom_agent_prompt(persona, chat_session)
assert result == "Custom system prompt"
def test_custom_persona_empty_string_system_prompt(self) -> None:
"""Test that custom persona with empty string system_prompt returns None."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt="",
replace_base_system_prompt=False,
)
chat_session = self._create_mock_chat_session(project=None)
result = get_custom_agent_prompt(persona, chat_session)
assert result is None
def test_custom_persona_none_system_prompt(self) -> None:
"""Test that custom persona with None system_prompt returns None."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt=None,
replace_base_system_prompt=False,
)
chat_session = self._create_mock_chat_session(project=None)
result = get_custom_agent_prompt(persona, chat_session)
assert result is None
def test_custom_persona_in_project_uses_persona_prompt(self) -> None:
"""Test that custom persona in a project uses persona's system_prompt, not project instructions."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt="Custom system prompt",
replace_base_system_prompt=False,
)
project = self._create_mock_project(instructions="Project instructions")
chat_session = self._create_mock_chat_session(project=project)
result = get_custom_agent_prompt(persona, chat_session)
# Should use persona's system_prompt, NOT project instructions
assert result == "Custom system prompt"
def test_custom_persona_replace_base_in_project(self) -> None:
"""Test that custom persona with replace_base_system_prompt=True in a project still returns None."""
persona = self._create_mock_persona(
persona_id=1,
system_prompt="Custom system prompt",
replace_base_system_prompt=True,
)
project = self._create_mock_project(instructions="Project instructions")
chat_session = self._create_mock_chat_session(project=project)
result = get_custom_agent_prompt(persona, chat_session)
# Should return None because replace_base_system_prompt=True
assert result is None
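Taken together, the tests above fully pin down the decision table for `get_custom_agent_prompt`. A self-contained sketch of that inferred behavior, using stand-in dataclasses rather than the real ORM models (the actual function operates on `Persona`/`ChatSession` from the Onyx codebase):

```python
from dataclasses import dataclass
from typing import Optional

# Assumed value for illustration; the real constant lives in onyx.configs.constants.
DEFAULT_PERSONA_ID = 0


@dataclass
class Persona:
    id: int
    system_prompt: Optional[str]
    replace_base_system_prompt: bool


@dataclass
class Project:
    instructions: str


@dataclass
class ChatSession:
    project: Optional[Project]


def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> Optional[str]:
    # Default persona: fall back to non-empty project instructions, else nothing.
    if persona.id == DEFAULT_PERSONA_ID:
        if chat_session.project and chat_session.project.instructions:
            return chat_session.project.instructions
        return None
    # A custom persona that replaces the base system prompt contributes no
    # additional custom prompt here.
    if persona.replace_base_system_prompt:
        return None
    # Otherwise use the persona's own prompt, treating "" the same as None.
    return persona.system_prompt or None
```

Note that for a custom persona inside a project, the persona's `system_prompt` wins over the project instructions, exactly as `test_custom_persona_in_project_uses_persona_prompt` asserts.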

View File

@@ -0,0 +1,618 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.airtable.airtable_connector import parse_airtable_url
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.models import Document
def _make_field_schema(field_id: str, name: str, field_type: str) -> MagicMock:
field = MagicMock()
field.id = field_id
field.name = name
field.type = field_type
return field
def _make_table_schema(
table_id: str,
table_name: str,
primary_field_id: str,
fields: list[MagicMock],
) -> MagicMock:
schema = MagicMock()
schema.id = table_id
schema.name = table_name
schema.primary_field_id = primary_field_id
schema.fields = fields
schema.views = []
return schema
def _make_record(record_id: str, fields: dict[str, Any]) -> dict[str, Any]:
return {"id": record_id, "fields": fields}
def _make_base_info(base_id: str, name: str) -> MagicMock:
info = MagicMock()
info.id = base_id
info.name = name
return info
def _make_table_obj(table_id: str, name: str) -> MagicMock:
obj = MagicMock()
obj.id = table_id
obj.name = name
return obj
def _setup_mock_api(
bases: list[dict[str, Any]],
) -> MagicMock:
"""Set up a mock AirtableApi with bases, tables, records, and schemas.
Args:
bases: List of dicts with keys: id, name, tables.
Each table is a dict with: id, name, primary_field_id, fields, records.
Each field is a dict with: id, name, type.
Each record is a dict with: id, fields.
"""
mock_api = MagicMock()
base_infos = [_make_base_info(b["id"], b["name"]) for b in bases]
mock_api.bases.return_value = base_infos
def base_side_effect(base_id: str) -> MagicMock:
mock_base = MagicMock()
base_data = next((b for b in bases if b["id"] == base_id), None)
if not base_data:
raise ValueError(f"Unknown base: {base_id}")
table_objs = [_make_table_obj(t["id"], t["name"]) for t in base_data["tables"]]
mock_base.tables.return_value = table_objs
return mock_base
mock_api.base.side_effect = base_side_effect
def table_side_effect(base_id: str, table_name_or_id: str) -> MagicMock:
base_data = next((b for b in bases if b["id"] == base_id), None)
if not base_data:
raise ValueError(f"Unknown base: {base_id}")
table_data = next(
(
t
for t in base_data["tables"]
if t["id"] == table_name_or_id or t["name"] == table_name_or_id
),
None,
)
if not table_data:
raise ValueError(f"Unknown table: {table_name_or_id}")
mock_table = MagicMock()
mock_table.name = table_data["name"]
mock_table.all.return_value = [
_make_record(r["id"], r["fields"]) for r in table_data["records"]
]
field_schemas = [
_make_field_schema(f["id"], f["name"], f["type"])
for f in table_data["fields"]
]
schema = _make_table_schema(
table_data["id"],
table_data["name"],
table_data["primary_field_id"],
field_schemas,
)
mock_table.schema.return_value = schema
return mock_table
mock_api.table.side_effect = table_side_effect
return mock_api
SAMPLE_BASES = [
{
"id": "appBASE1",
"name": "Base One",
"tables": [
{
"id": "tblTABLE1",
"name": "Table A",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
{"id": "fld2", "name": "Notes", "type": "multilineText"},
],
"records": [
{"id": "recA1", "fields": {"Name": "Alice", "Notes": "Note A"}},
{"id": "recA2", "fields": {"Name": "Bob", "Notes": "Note B"}},
],
},
{
"id": "tblTABLE2",
"name": "Table B",
"primary_field_id": "fld3",
"fields": [
{"id": "fld3", "name": "Title", "type": "singleLineText"},
{"id": "fld4", "name": "Status", "type": "singleSelect"},
],
"records": [
{"id": "recB1", "fields": {"Title": "Task 1", "Status": "Done"}},
],
},
],
},
{
"id": "appBASE2",
"name": "Base Two",
"tables": [
{
"id": "tblTABLE3",
"name": "Table C",
"primary_field_id": "fld5",
"fields": [
{"id": "fld5", "name": "Item", "type": "singleLineText"},
],
"records": [
{"id": "recC1", "fields": {"Item": "Widget"}},
],
},
],
},
]
def _collect_docs(connector: AirtableConnector) -> list[Document]:
docs: list[Document] = []
for batch in connector.load_from_state():
for item in batch:
if isinstance(item, Document):
docs.append(item)
return docs
class TestIndexAll:
@patch("time.sleep")
def test_index_all_discovers_all_bases_and_tables(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
connector = AirtableConnector()
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
# 2 records from Table A + 1 from Table B + 1 from Table C = 4
assert len(docs) == 4
doc_ids = {d.id for d in docs}
assert doc_ids == {
"airtable__recA1",
"airtable__recA2",
"airtable__recB1",
"airtable__recC1",
}
@patch("time.sleep")
def test_index_all_semantic_id_includes_base_name(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
connector = AirtableConnector()
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
docs_by_id = {d.id: d for d in docs}
assert (
docs_by_id["airtable__recA1"].semantic_identifier
== "Base One > Table A: Alice"
)
assert (
docs_by_id["airtable__recB1"].semantic_identifier
== "Base One > Table B: Task 1"
)
assert (
docs_by_id["airtable__recC1"].semantic_identifier
== "Base Two > Table C: Widget"
)
@patch("time.sleep")
def test_index_all_hierarchy_source_path(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
"""Verify doc_metadata hierarchy source_path is [base_name, table_name]."""
connector = AirtableConnector()
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
docs_by_id = {d.id: d for d in docs}
doc_a1 = docs_by_id["airtable__recA1"]
assert doc_a1.doc_metadata is not None
assert doc_a1.doc_metadata["hierarchy"]["source_path"] == [
"Base One",
"Table A",
]
assert doc_a1.doc_metadata["hierarchy"]["base_name"] == "Base One"
assert doc_a1.doc_metadata["hierarchy"]["table_name"] == "Table A"
doc_c1 = docs_by_id["airtable__recC1"]
assert doc_c1.doc_metadata is not None
assert doc_c1.doc_metadata["hierarchy"]["source_path"] == [
"Base Two",
"Table C",
]
@patch("time.sleep")
def test_index_all_empty_account(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
connector = AirtableConnector()
mock_api = MagicMock()
mock_api.bases.return_value = []
connector._airtable_client = mock_api
docs = _collect_docs(connector)
assert len(docs) == 0
@patch("time.sleep")
def test_index_all_skips_failing_table(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
"""If one table fails, other tables should still be indexed."""
bases = [
{
"id": "appBASE1",
"name": "Base One",
"tables": [
{
"id": "tblGOOD",
"name": "Good Table",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
],
"records": [
{"id": "recOK", "fields": {"Name": "Works"}},
],
},
{
"id": "tblBAD",
"name": "Bad Table",
"primary_field_id": "fldX",
"fields": [],
"records": [],
},
],
},
]
mock_api = _setup_mock_api(bases)
# Make the bad table raise an error when fetching records
original_table_side_effect = mock_api.table.side_effect
def table_with_failure(base_id: str, table_name_or_id: str) -> MagicMock:
if table_name_or_id == "tblBAD":
mock_table = MagicMock()
mock_table.all.side_effect = Exception("API Error")
mock_table.schema.side_effect = Exception("API Error")
return mock_table
return original_table_side_effect(base_id, table_name_or_id)
mock_api.table.side_effect = table_with_failure
connector = AirtableConnector()
connector._airtable_client = mock_api
docs = _collect_docs(connector)
# Only the good table's records should come through
assert len(docs) == 1
assert docs[0].id == "airtable__recOK"
@patch("time.sleep")
def test_index_all_skips_failing_base(
self, mock_sleep: MagicMock # noqa: ARG002
) -> None:
"""If listing tables for a base fails, other bases should still be indexed."""
bases_data = [
{
"id": "appGOOD",
"name": "Good Base",
"tables": [
{
"id": "tblOK",
"name": "OK Table",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
],
"records": [
{"id": "recOK", "fields": {"Name": "Works"}},
],
},
],
},
]
mock_api = _setup_mock_api(bases_data)
# Add a bad base that fails on tables()
bad_base_info = _make_base_info("appBAD", "Bad Base")
mock_api.bases.return_value = [
bad_base_info,
*mock_api.bases.return_value,
]
original_base_side_effect = mock_api.base.side_effect
def base_with_failure(base_id: str) -> MagicMock:
if base_id == "appBAD":
mock_base = MagicMock()
mock_base.tables.side_effect = Exception("Permission denied")
return mock_base
return original_base_side_effect(base_id)
mock_api.base.side_effect = base_with_failure
connector = AirtableConnector()
connector._airtable_client = mock_api
docs = _collect_docs(connector)
assert len(docs) == 1
assert docs[0].id == "airtable__recOK"
class TestSpecificTableMode:
def test_specific_table_unchanged(self) -> None:
"""Verify the original single-table behavior still works."""
bases = [
{
"id": "appBASE1",
"name": "Base One",
"tables": [
{
"id": "tblTABLE1",
"name": "Table A",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
{"id": "fld2", "name": "Notes", "type": "multilineText"},
],
"records": [
{
"id": "recA1",
"fields": {"Name": "Alice", "Notes": "Note"},
},
],
},
],
},
]
mock_api = _setup_mock_api(bases)
connector = AirtableConnector(
base_id="appBASE1",
table_name_or_id="tblTABLE1",
)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
assert len(docs) == 1
assert docs[0].id == "airtable__recA1"
# No base name prefix in specific mode
assert docs[0].semantic_identifier == "Table A: Alice"
def test_specific_table_resolves_base_name_for_hierarchy(self) -> None:
"""In specific mode, bases() is called to resolve the base name for hierarchy."""
bases = [
{
"id": "appBASE1",
"name": "Base One",
"tables": [
{
"id": "tblTABLE1",
"name": "Table A",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
],
"records": [
{"id": "recA1", "fields": {"Name": "Test"}},
],
},
],
},
]
mock_api = _setup_mock_api(bases)
connector = AirtableConnector(
base_id="appBASE1",
table_name_or_id="tblTABLE1",
)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
# bases() is called to resolve the base name for hierarchy source_path
mock_api.bases.assert_called_once()
# But base().tables() should NOT be called (no discovery)
mock_api.base.assert_not_called()
# Semantic identifier should NOT include base name in specific mode
assert docs[0].semantic_identifier == "Table A: Test"
# Hierarchy should include base name for Craft file system
assert docs[0].doc_metadata is not None
assert docs[0].doc_metadata["hierarchy"]["source_path"] == [
"Base One",
"Table A",
]
class TestValidateConnectorSettings:
def test_validate_index_all_success(self) -> None:
connector = AirtableConnector()
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
# Should not raise
connector.validate_connector_settings()
def test_validate_index_all_no_bases(self) -> None:
connector = AirtableConnector()
mock_api = MagicMock()
mock_api.bases.return_value = []
connector._airtable_client = mock_api
with pytest.raises(ConnectorValidationError, match="No bases found"):
connector.validate_connector_settings()
def test_validate_specific_table_success(self) -> None:
connector = AirtableConnector(
base_id="appBASE1",
table_name_or_id="tblTABLE1",
)
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
# Should not raise
connector.validate_connector_settings()
def test_validate_empty_fields_auto_detects_index_all(self) -> None:
"""Empty base_id + table_name_or_id auto-detects as index_all mode."""
connector = AirtableConnector(
base_id="",
table_name_or_id="",
)
assert connector.index_all is True
# Validation should go through the index_all path
mock_api = _setup_mock_api(SAMPLE_BASES)
connector._airtable_client = mock_api
connector.validate_connector_settings()
def test_validate_specific_table_api_error(self) -> None:
connector = AirtableConnector(
base_id="appBAD",
table_name_or_id="tblBAD",
)
mock_api = MagicMock()
mock_table = MagicMock()
mock_table.schema.side_effect = Exception("Not found")
mock_api.table.return_value = mock_table
connector._airtable_client = mock_api
with pytest.raises(ConnectorValidationError, match="Failed to access table"):
connector.validate_connector_settings()
class TestParseAirtableUrl:
def test_full_url_with_view(self) -> None:
base_id, table_id, view_id = parse_airtable_url(
"https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV/viwa3yxZvqWnyXftm?blocks=hide"
)
assert base_id == "appZqBgQFQ6kWyeZK"
assert table_id == "tblc9prNLypy7olTV"
assert view_id == "viwa3yxZvqWnyXftm"
def test_url_without_view(self) -> None:
base_id, table_id, view_id = parse_airtable_url(
"https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV"
)
assert base_id == "appZqBgQFQ6kWyeZK"
assert table_id == "tblc9prNLypy7olTV"
assert view_id is None
def test_url_without_query_params(self) -> None:
base_id, table_id, view_id = parse_airtable_url(
"https://airtable.com/appABC123/tblDEF456/viwGHI789"
)
assert base_id == "appABC123"
assert table_id == "tblDEF456"
assert view_id == "viwGHI789"
def test_url_with_trailing_whitespace(self) -> None:
base_id, table_id, view_id = parse_airtable_url(
" https://airtable.com/appABC123/tblDEF456 "
)
assert base_id == "appABC123"
assert table_id == "tblDEF456"
def test_invalid_url_raises(self) -> None:
with pytest.raises(ValueError, match="Could not parse"):
parse_airtable_url("https://google.com/something")
def test_missing_table_raises(self) -> None:
with pytest.raises(ValueError, match="Could not parse"):
parse_airtable_url("https://airtable.com/appABC123")
def test_empty_string_raises(self) -> None:
with pytest.raises(ValueError, match="Could not parse"):
parse_airtable_url("")
class TestAirtableUrlConnector:
def test_url_sets_base_and_table_ids(self) -> None:
connector = AirtableConnector(
airtable_url="https://airtable.com/appZqBgQFQ6kWyeZK/tblc9prNLypy7olTV/viwa3yxZvqWnyXftm?blocks=hide"
)
assert connector.base_id == "appZqBgQFQ6kWyeZK"
assert connector.table_name_or_id == "tblc9prNLypy7olTV"
assert connector.view_id == "viwa3yxZvqWnyXftm"
def test_url_without_view_leaves_view_none(self) -> None:
connector = AirtableConnector(airtable_url="https://airtable.com/appABC/tblDEF")
assert connector.base_id == "appABC"
assert connector.table_name_or_id == "tblDEF"
assert connector.view_id is None
def test_url_overrides_explicit_base_and_table(self) -> None:
connector = AirtableConnector(
base_id="appOLD",
table_name_or_id="tblOLD",
airtable_url="https://airtable.com/appNEW/tblNEW",
)
assert connector.base_id == "appNEW"
assert connector.table_name_or_id == "tblNEW"
def test_url_indexes_correctly(self) -> None:
"""End-to-end: URL-configured connector fetches from the right table."""
bases = [
{
"id": "appFromUrl",
"name": "URL Base",
"tables": [
{
"id": "tblFromUrl",
"name": "URL Table",
"primary_field_id": "fld1",
"fields": [
{"id": "fld1", "name": "Name", "type": "singleLineText"},
],
"records": [
{"id": "recURL1", "fields": {"Name": "From URL"}},
],
},
],
},
]
mock_api = _setup_mock_api(bases)
connector = AirtableConnector(
airtable_url="https://airtable.com/appFromUrl/tblFromUrl/viwABC"
)
connector._airtable_client = mock_api
docs = _collect_docs(connector)
assert len(docs) == 1
assert docs[0].id == "airtable__recURL1"
assert docs[0].semantic_identifier == "URL Table: From URL"
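The `TestParseAirtableUrl` cases above characterize the expected parsing: base/table ids are required, the view id is optional, surrounding whitespace and query parameters are tolerated, and anything else raises `ValueError`. One plausible regex-based sketch satisfying those cases (the real implementation in `airtable_connector.py` may differ in detail):

```python
import re
from typing import Optional, Tuple

# Airtable ids carry fixed prefixes: app (base), tbl (table), viw (view).
_AIRTABLE_URL_RE = re.compile(
    r"https://airtable\.com/"
    r"(app[A-Za-z0-9]+)/"       # base id (required)
    r"(tbl[A-Za-z0-9]+)"        # table id (required)
    r"(?:/(viw[A-Za-z0-9]+))?"  # view id (optional)
)


def parse_airtable_url(url: str) -> Tuple[str, str, Optional[str]]:
    # re.match (not fullmatch) so trailing query strings like ?blocks=hide
    # are simply ignored; strip() tolerates surrounding whitespace.
    match = _AIRTABLE_URL_RE.match(url.strip())
    if not match:
        raise ValueError(f"Could not parse Airtable URL: {url!r}")
    return match.group(1), match.group(2), match.group(3)
```

Using anchored prefix matching with `re.match` keeps the "Could not parse" failure mode for bare-base URLs (`.../appABC123` with no table segment) without any extra validation code.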

View File

@@ -0,0 +1,136 @@
"""Unit tests for SharepointConnector._create_rest_client_context caching."""
from __future__ import annotations
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.sharepoint.connector import _REST_CTX_MAX_AGE_S
from onyx.connectors.sharepoint.connector import SharepointConnector
SITE_A = "https://tenant.sharepoint.com/sites/SiteA"
SITE_B = "https://tenant.sharepoint.com/sites/SiteB"
FAKE_CREDS = {"sp_client_id": "x", "sp_directory_id": "y"}
def _make_connector() -> SharepointConnector:
"""Return a SharepointConnector with minimal credentials wired up."""
connector = SharepointConnector(sites=[SITE_A])
connector.msal_app = MagicMock()
connector.sp_tenant_domain = "tenant"
connector._credential_json = FAKE_CREDS
return connector
def _noop_load_credentials(connector: SharepointConnector) -> MagicMock:
"""Patch load_credentials to just swap in a fresh MagicMock for msal_app."""
def _fake_load(creds: dict) -> None: # noqa: ARG001, ARG002
connector.msal_app = MagicMock()
mock = MagicMock(side_effect=_fake_load)
connector.load_credentials = mock # type: ignore[method-assign]
return mock
def _fresh_client_context() -> MagicMock:
"""Return a MagicMock for ClientContext that produces a distinct object per call."""
mock_cls = MagicMock()
# Each ClientContext(url).with_access_token(cb) returns a unique sentinel
mock_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
return mock_cls
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
@patch("onyx.connectors.sharepoint.connector.ClientContext")
def test_returns_cached_context_within_max_age(
mock_client_ctx_cls: MagicMock,
_mock_acquire: MagicMock,
) -> None:
"""Repeated calls with the same site_url within the TTL return the same object."""
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
connector = _make_connector()
_noop_load_credentials(connector)
ctx1 = connector._create_rest_client_context(SITE_A)
ctx2 = connector._create_rest_client_context(SITE_A)
assert ctx1 is ctx2
assert mock_client_ctx_cls.call_count == 1
@patch("onyx.connectors.sharepoint.connector.time")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
@patch("onyx.connectors.sharepoint.connector.ClientContext")
def test_rebuilds_context_after_max_age(
mock_client_ctx_cls: MagicMock,
_mock_acquire: MagicMock,
mock_time: MagicMock,
) -> None:
"""After _REST_CTX_MAX_AGE_S the cached context is replaced."""
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
connector = _make_connector()
_noop_load_credentials(connector)
mock_time.monotonic.return_value = 0.0
ctx1 = connector._create_rest_client_context(SITE_A)
# Just past the boundary — should rebuild
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 1
ctx2 = connector._create_rest_client_context(SITE_A)
assert ctx1 is not ctx2
assert mock_client_ctx_cls.call_count == 2
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
@patch("onyx.connectors.sharepoint.connector.ClientContext")
def test_rebuilds_context_on_site_change(
mock_client_ctx_cls: MagicMock,
_mock_acquire: MagicMock,
) -> None:
"""Switching to a different site_url forces a new context."""
mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
connector = _make_connector()
_noop_load_credentials(connector)
ctx_a = connector._create_rest_client_context(SITE_A)
ctx_b = connector._create_rest_client_context(SITE_B)
assert ctx_a is not ctx_b
assert mock_client_ctx_cls.call_count == 2
@patch("onyx.connectors.sharepoint.connector.time")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
@patch("onyx.connectors.sharepoint.connector.ClientContext")
def test_load_credentials_called_on_rebuild(
_mock_client_ctx_cls: MagicMock,
_mock_acquire: MagicMock,
mock_time: MagicMock,
) -> None:
"""load_credentials is called every time the context is rebuilt."""
_mock_client_ctx_cls.side_effect = lambda url: MagicMock() # noqa: ARG005
connector = _make_connector()
mock_load = _noop_load_credentials(connector)
# First call — rebuild (no cache yet)
mock_time.monotonic.return_value = 0.0
connector._create_rest_client_context(SITE_A)
assert mock_load.call_count == 1
# Second call — cache hit, no rebuild
mock_time.monotonic.return_value = 100.0
connector._create_rest_client_context(SITE_A)
assert mock_load.call_count == 1
# Third call — expired, rebuild
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 1
connector._create_rest_client_context(SITE_A)
assert mock_load.call_count == 2
# Fourth call — site change, rebuild
mock_time.monotonic.return_value = _REST_CTX_MAX_AGE_S + 2
connector._create_rest_client_context(SITE_B)
assert mock_load.call_count == 3

View File

@@ -1,3 +1,7 @@
import os
import threading
import time
from typing import Any
from unittest.mock import ANY
from unittest.mock import patch
@@ -137,42 +141,44 @@ def default_multi_llm() -> LitellmLLM:
def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
# Mock the litellm.completion function
with patch("litellm.completion") as mock_completion:
# Create a mock response with multiple tool calls using litellm objects
mock_response = litellm.ModelResponse(
    id="chatcmpl-123",
    choices=[
        litellm.Choices(
            finish_reason="tool_calls",
            index=0,
            message=litellm.Message(
                content=None,
                role="assistant",
                tool_calls=[
                    litellm.ChatCompletionMessageToolCall(
                        id="call_1",
                        function=LiteLLMFunction(
                            name="get_weather",
                            arguments='{"location": "New York"}',
                        ),
                        type="function",
                    ),
                    litellm.ChatCompletionMessageToolCall(
                        id="call_2",
                        function=LiteLLMFunction(
                            name="get_time", arguments='{"timezone": "EST"}'
                        ),
                        type="function",
                    ),
                ],
            ),
        )
    ],
    model="gpt-3.5-turbo",
    usage=litellm.Usage(
        prompt_tokens=50, completion_tokens=30, total_tokens=80
    ),
)
mock_completion.return_value = mock_response
# invoke() internally uses stream=True and reassembles via
# stream_chunk_builder, so the mock must return stream chunks.
mock_stream_chunks = [
    litellm.ModelResponse(
        id="chatcmpl-123",
        choices=[
            litellm.Choices(
                delta=_create_delta(
                    role="assistant",
                    tool_calls=[
                        ChatCompletionDeltaToolCall(
                            id="call_1",
                            function=LiteLLMFunction(
                                name="get_weather",
                                arguments='{"location": "New York"}',
                            ),
                            type="function",
                            index=0,
                        ),
                        ChatCompletionDeltaToolCall(
                            id="call_2",
                            function=LiteLLMFunction(
                                name="get_time",
                                arguments='{"timezone": "EST"}',
                            ),
                            type="function",
                            index=1,
                        ),
                    ],
                ),
                finish_reason="tool_calls",
                index=0,
            )
        ],
        model="gpt-3.5-turbo",
    ),
]
mock_completion.return_value = mock_stream_chunks
# Define input messages
messages: LanguageModelInput = [
@@ -246,11 +252,12 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
],
tools=tools,
tool_choice=None,
stream=False,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
max_tokens=None,
client=ANY, # HTTPHandler instance created per-request
stream_options={"include_usage": True},
parallel_tool_calls=True,
mock_response=MOCK_LLM_RESPONSE,
allowed_openai_params=["tool_choice"],
@@ -507,21 +514,20 @@ def test_openai_chat_omits_reasoning_params() -> None:
"onyx.llm.multi_llm.is_true_openai_model", return_value=True
) as mock_is_openai,
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-5-chat",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-5-chat",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -539,21 +545,20 @@ def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -573,21 +578,20 @@ def test_user_identity_user_id_truncated_to_64_chars(
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
long_user_id = "u" * 82
@@ -607,21 +611,20 @@ def test_user_identity_metadata_disabled_omits_identity(
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -654,21 +657,20 @@ def test_existing_metadata_pass_through_when_identity_disabled() -> None:
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -688,18 +690,20 @@ def test_openai_model_invoke_uses_httphandler_client(
from litellm import HTTPHandler
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
default_multi_llm.invoke(messages)
@@ -737,18 +741,20 @@ def test_anthropic_model_passes_no_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="claude-3-opus-20240229",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="claude-3-opus-20240229",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -769,18 +775,20 @@ def test_bedrock_model_passes_no_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -809,18 +817,20 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="gpt-4o",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-4o",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -828,3 +838,372 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
mock_completion.assert_called_once()
kwargs = mock_completion.call_args.kwargs
assert isinstance(kwargs["client"], HTTPHandler)
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
# Assign some environment variables
EXPECTED_ENV_VARS = {
"TEST_ENV_VAR": "test_value",
"ANOTHER_ONE": "1",
"THIRD_ONE": "2",
}
CUSTOM_CONFIG = {
"TEST_ENV_VAR": "fdsfsdf",
"ANOTHER_ONE": "3",
"THIS_IS_RANDOM": "123213",
}
for env_var, value in EXPECTED_ENV_VARS.items():
monkeypatch.setenv(env_var, value)
model_provider = LlmProviderNames.OPENAI
model_name = "gpt-3.5-turbo"
llm = LitellmLLM(
api_key="test_key",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
model_kwargs={"metadata": {"foo": "bar"}},
custom_config=CUSTOM_CONFIG,
)
# When custom_config is set, invoke() internally uses stream=True and
# reassembles via stream_chunk_builder, so the mock must return stream chunks.
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
def on_litellm_completion(
**kwargs: dict[str, Any], # noqa: ARG001
) -> list[litellm.ModelResponse]:
# Validate that the environment variables are those in custom config
for env_var, value in CUSTOM_CONFIG.items():
assert env_var in os.environ
assert os.environ[env_var] == value
return mock_stream_chunks
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_completion.side_effect = on_litellm_completion
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
llm.invoke(messages, user_identity=identity)
mock_completion.assert_called_once()
kwargs = mock_completion.call_args.kwargs
assert kwargs["stream"] is True
assert "user" not in kwargs
assert kwargs["metadata"]["foo"] == "bar"
# Check that the environment variables are back to the original values
for env_var, value in EXPECTED_ENV_VARS.items():
assert env_var in os.environ
assert os.environ[env_var] == value
# Check that temporary env var from CUSTOM_CONFIG is no longer set
assert "THIS_IS_RANDOM" not in os.environ
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify env vars are restored even when an exception occurs during LLM invocation."""
# Assign some environment variables
EXPECTED_ENV_VARS = {
"TEST_ENV_VAR": "test_value",
"ANOTHER_ONE": "1",
"THIRD_ONE": "2",
}
CUSTOM_CONFIG = {
"TEST_ENV_VAR": "fdsfsdf",
"ANOTHER_ONE": "3",
"THIS_IS_RANDOM": "123213",
}
for env_var, value in EXPECTED_ENV_VARS.items():
monkeypatch.setenv(env_var, value)
model_provider = LlmProviderNames.OPENAI
model_name = "gpt-3.5-turbo"
llm = LitellmLLM(
api_key="test_key",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
model_kwargs={"metadata": {"foo": "bar"}},
custom_config=CUSTOM_CONFIG,
)
def on_litellm_completion_raises(**kwargs: dict[str, Any]) -> None: # noqa: ARG001
# Validate that the environment variables are those in custom config
for env_var, value in CUSTOM_CONFIG.items():
assert env_var in os.environ
assert os.environ[env_var] == value
# Simulate an error during LLM call
raise RuntimeError("Simulated LLM API failure")
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_completion.side_effect = on_litellm_completion_raises
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
llm.invoke(messages, user_identity=identity)
mock_completion.assert_called_once()
# Check that the environment variables are back to the original values
for env_var, value in EXPECTED_ENV_VARS.items():
assert env_var in os.environ
assert os.environ[env_var] == value
# Check that temporary env var from CUSTOM_CONFIG is no longer set
assert "THIS_IS_RANDOM" not in os.environ
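The restore-on-exception behavior pinned down by the two tests above comes down to a save/override/finally-restore pattern. A minimal sketch of that pattern, assuming an illustrative name `temporary_env` (the project's actual helper may differ):

```python
import os
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def temporary_env(overrides: dict[str, str]) -> Iterator[None]:
    """Apply env-var overrides, restoring prior values even on exception.

    Illustrative sketch of the behavior the tests above assert; not the
    project's actual helper.
    """
    # Snapshot prior values (None means the var was not set at all).
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, prior in saved.items():
            if prior is None:
                os.environ.pop(key, None)  # var did not exist before
            else:
                os.environ[key] = prior
```

Because the restore runs in `finally`, a `RuntimeError` raised mid-call (as in the test above) still leaves `os.environ` in its original state.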
@pytest.mark.parametrize("use_stream", [False, True], ids=["invoke", "stream"])
def test_multithreaded_custom_config_isolation(
monkeypatch: pytest.MonkeyPatch,
use_stream: bool,
) -> None:
"""Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.
Two LitellmLLM instances with different custom_config dicts call invoke/stream
concurrently. The _env_lock in temporary_env_and_lock serializes their access so
each call only ever sees its own env vars—never the other's.
"""
# Ensure these keys start unset
monkeypatch.delenv("SHARED_KEY", raising=False)
monkeypatch.delenv("LLM_A_ONLY", raising=False)
monkeypatch.delenv("LLM_B_ONLY", raising=False)
CONFIG_A = {
"SHARED_KEY": "value_from_A",
"LLM_A_ONLY": "a_secret",
}
CONFIG_B = {
"SHARED_KEY": "value_from_B",
"LLM_B_ONLY": "b_secret",
}
all_env_keys = list(set(list(CONFIG_A.keys()) + list(CONFIG_B.keys())))
model_provider = LlmProviderNames.OPENAI
model_name = "gpt-3.5-turbo"
llm_a = LitellmLLM(
api_key="key_a",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
custom_config=CONFIG_A,
)
llm_b = LitellmLLM(
api_key="key_b",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
custom_config=CONFIG_B,
)
# Both invoke (with custom_config) and stream use stream=True at the
# litellm level, so the mock must return stream chunks.
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hi"),
finish_reason="stop",
index=0,
)
],
model=model_name,
),
]
# Track what each call observed inside litellm.completion.
# Keyed by api_key so we can identify which LLM instance made the call.
observed_envs: dict[str, dict[str, str | None]] = {}
def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
time.sleep(0.1)  # Hold the lock long enough that the other thread contends for it
api_key = kwargs.get("api_key", "")
label = "A" if api_key == "key_a" else "B"
snapshot: dict[str, str | None] = {}
for key in all_env_keys:
snapshot[key] = os.environ.get(key)
observed_envs[label] = snapshot
return mock_stream_chunks
errors: list[Exception] = []
def run_llm(llm: LitellmLLM) -> None:
try:
messages: LanguageModelInput = [UserMessage(content="Hi")]
if use_stream:
list(llm.stream(messages))
else:
llm.invoke(messages)
except Exception as e:
errors.append(e)
with patch("litellm.completion", side_effect=fake_completion):
t_a = threading.Thread(target=run_llm, args=(llm_a,))
t_b = threading.Thread(target=run_llm, args=(llm_b,))
t_a.start()
t_b.start()
t_a.join(timeout=10)
t_b.join(timeout=10)
assert not errors, f"Thread errors: {errors}"
assert "A" in observed_envs and "B" in observed_envs
# Thread A must have seen its own config for SHARED_KEY, not B's
assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
# A must NOT see B's exclusive key
assert observed_envs["A"]["LLM_B_ONLY"] is None
# Thread B must have seen its own config for SHARED_KEY, not A's
assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
# B must NOT see A's exclusive key
assert observed_envs["B"]["LLM_A_ONLY"] is None
# After both calls, env should be clean
assert os.environ.get("SHARED_KEY") is None
assert os.environ.get("LLM_A_ONLY") is None
assert os.environ.get("LLM_B_ONLY") is None
def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
"""Verify that invoke() without custom_config does not acquire the env lock.
Two LitellmLLM instances without custom_config call invoke concurrently.
Both should run with stream=False, never touch the env lock, and complete
without blocking each other.
"""
from onyx.llm import multi_llm as multi_llm_module
model_provider = LlmProviderNames.OPENAI
model_name = "gpt-3.5-turbo"
llm_a = LitellmLLM(
api_key="key_a",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
)
llm_b = LitellmLLM(
api_key="key_b",
timeout=30,
model_provider=model_provider,
model_name=model_name,
max_input_tokens=get_max_input_tokens(
model_provider=model_provider,
model_name=model_name,
),
)
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hi"),
finish_reason="stop",
index=0,
)
],
model=model_name,
),
]
call_kwargs: dict[str, dict[str, Any]] = {}
def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
api_key = kwargs.get("api_key", "")
label = "A" if api_key == "key_a" else "B"
call_kwargs[label] = kwargs
return mock_stream_chunks
errors: list[Exception] = []
def run_llm(llm: LitellmLLM) -> None:
try:
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
except Exception as e:
errors.append(e)
with (
patch("litellm.completion", side_effect=fake_completion),
patch.object(
multi_llm_module,
"temporary_env_and_lock",
wraps=multi_llm_module.temporary_env_and_lock,
) as mock_env_lock,
):
t_a = threading.Thread(target=run_llm, args=(llm_a,))
t_b = threading.Thread(target=run_llm, args=(llm_b,))
t_a.start()
t_b.start()
t_a.join(timeout=10)
t_b.join(timeout=10)
assert not errors, f"Thread errors: {errors}"
assert "A" in call_kwargs and "B" in call_kwargs
# invoke() always uses stream=True internally (reassembles via stream_chunk_builder)
assert call_kwargs["A"]["stream"] is True
assert call_kwargs["B"]["stream"] is True
# The env lock context manager should never have been called
mock_env_lock.assert_not_called()

View File

@@ -0,0 +1,15 @@
from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION
from onyx.prompts.prompt_utils import replace_reminder_tag
def test_replace_reminder_tag_pattern() -> None:
prompt = "Some text {{REMINDER_TAG_DESCRIPTION}} more text"
result = replace_reminder_tag(prompt)
assert "{{REMINDER_TAG_DESCRIPTION}}" not in result
assert REMINDER_TAG_DESCRIPTION in result
def test_replace_reminder_tag_no_pattern() -> None:
prompt = "Some text without any pattern"
result = replace_reminder_tag(prompt)
assert result == prompt

View File

@@ -0,0 +1,374 @@
"""Unit tests for Zed-style ACP session management in KubernetesSandboxManager.
These tests verify that the KubernetesSandboxManager correctly:
- Maintains one shared ACPExecClient per sandbox
- Maps craft sessions to ACP sessions on the shared client
- Replaces dead clients and re-creates sessions
- Cleans up on terminate/cleanup
All external dependencies (K8s, WebSockets, packet logging) are mocked.
"""
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# The fully-qualified path to the module under test, used for patching
_K8S_MODULE = "onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager"
_ACP_CLIENT_CLASS = f"{_K8S_MODULE}.ACPExecClient"
_GET_PACKET_LOGGER = f"{_K8S_MODULE}.get_packet_logger"
def _make_mock_event() -> MagicMock:
"""Create a mock ACP event."""
return MagicMock(name="mock_acp_event")
def _make_mock_client(
is_running: bool = True,
session_ids: list[str] | None = None,
) -> MagicMock:
"""Create a mock ACPExecClient with configurable state.
Args:
is_running: Whether the client appears running
session_ids: List of ACP session IDs the client tracks
"""
mock_client = MagicMock()
type(mock_client).is_running = property(lambda _self: is_running)
type(mock_client).session_ids = property(
lambda _self: session_ids if session_ids is not None else []
)
mock_client.start.return_value = None
mock_client.stop.return_value = None
# get_or_create_session returns a unique ACP session ID
mock_client.get_or_create_session.return_value = f"acp-session-{uuid4().hex[:8]}"
mock_event = _make_mock_event()
mock_client.send_message.return_value = iter([mock_event])
return mock_client
def _drain_generator(gen: Generator[Any, None, None]) -> list[Any]:
"""Consume a generator and return all yielded values as a list."""
return list(gen)
# ---------------------------------------------------------------------------
# Fixture: fresh KubernetesSandboxManager instance
# ---------------------------------------------------------------------------
@pytest.fixture()
def manager() -> Generator[Any, None, None]:
"""Create a fresh KubernetesSandboxManager instance with all externals mocked."""
with (
patch(f"{_K8S_MODULE}.config") as _mock_config,
patch(f"{_K8S_MODULE}.client") as _mock_k8s_client,
patch(f"{_K8S_MODULE}.k8s_stream"),
patch(_GET_PACKET_LOGGER) as mock_get_logger,
):
mock_packet_logger = MagicMock()
mock_get_logger.return_value = mock_packet_logger
_mock_config.load_incluster_config.return_value = None
_mock_config.ConfigException = Exception
_mock_k8s_client.ApiClient.return_value = MagicMock()
_mock_k8s_client.CoreV1Api.return_value = MagicMock()
_mock_k8s_client.BatchV1Api.return_value = MagicMock()
_mock_k8s_client.NetworkingV1Api.return_value = MagicMock()
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
KubernetesSandboxManager,
)
KubernetesSandboxManager._instance = None
mgr = KubernetesSandboxManager()
yield mgr
KubernetesSandboxManager._instance = None
# ---------------------------------------------------------------------------
# Tests: Shared client lifecycle
# ---------------------------------------------------------------------------
def test_send_message_creates_shared_client_on_first_call(manager: Any) -> None:
"""First call to send_message() should create one shared ACPExecClient
for the sandbox, create an ACP session, and yield events."""
sandbox_id: UUID = uuid4()
session_id: UUID = uuid4()
message = "hello world"
mock_event = _make_mock_event()
mock_client = _make_mock_client(is_running=True)
acp_session_id = "acp-session-abc"
mock_client.get_or_create_session.return_value = acp_session_id
# session_ids must include the created session for validation
type(mock_client).session_ids = property(lambda _: [acp_session_id])
mock_client.send_message.return_value = iter([mock_event])
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
events = _drain_generator(manager.send_message(sandbox_id, session_id, message))
# Verify shared client was constructed once
MockClass.assert_called_once()
# Verify start() was called with /workspace (not session-specific path)
mock_client.start.assert_called_once_with(cwd="/workspace")
# Verify get_or_create_session was called with the session path
expected_cwd = f"/workspace/sessions/{session_id}"
mock_client.get_or_create_session.assert_called_once_with(cwd=expected_cwd)
# Verify send_message was called with correct args
mock_client.send_message.assert_called_once_with(message, session_id=acp_session_id)
# Verify we got the event
assert mock_event in events
# Verify shared client is cached by sandbox_id
assert sandbox_id in manager._acp_clients
assert manager._acp_clients[sandbox_id] is mock_client
# Verify session mapping exists
assert (sandbox_id, session_id) in manager._acp_session_ids
assert manager._acp_session_ids[(sandbox_id, session_id)] == acp_session_id
def test_send_message_reuses_shared_client_for_same_session(manager: Any) -> None:
"""Second call with the same session should reuse the shared client
and the same ACP session ID."""
sandbox_id: UUID = uuid4()
session_id: UUID = uuid4()
mock_event_1 = _make_mock_event()
mock_event_2 = _make_mock_event()
mock_client = _make_mock_client(is_running=True)
acp_session_id = "acp-session-reuse"
mock_client.get_or_create_session.return_value = acp_session_id
type(mock_client).session_ids = property(lambda _: [acp_session_id])
mock_client.send_message.side_effect = [
iter([mock_event_1]),
iter([mock_event_2]),
]
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
events_1 = _drain_generator(
manager.send_message(sandbox_id, session_id, "first")
)
events_2 = _drain_generator(
manager.send_message(sandbox_id, session_id, "second")
)
# Constructor called only ONCE (shared client)
MockClass.assert_called_once()
# start() called only once
mock_client.start.assert_called_once()
# get_or_create_session called only once (second call uses cached mapping)
mock_client.get_or_create_session.assert_called_once()
# send_message called twice with same ACP session ID
assert mock_client.send_message.call_count == 2
assert mock_event_1 in events_1
assert mock_event_2 in events_2
def test_send_message_different_sessions_share_client(manager: Any) -> None:
"""Two different craft sessions on the same sandbox should share the
same ACPExecClient but have different ACP sessions."""
sandbox_id: UUID = uuid4()
session_id_a: UUID = uuid4()
session_id_b: UUID = uuid4()
mock_client = _make_mock_client(is_running=True)
acp_session_a = "acp-session-a"
acp_session_b = "acp-session-b"
mock_client.get_or_create_session.side_effect = [acp_session_a, acp_session_b]
type(mock_client).session_ids = property(lambda _: [acp_session_a, acp_session_b])
mock_event_a = _make_mock_event()
mock_event_b = _make_mock_event()
mock_client.send_message.side_effect = [
iter([mock_event_a]),
iter([mock_event_b]),
]
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
events_a = _drain_generator(
manager.send_message(sandbox_id, session_id_a, "msg a")
)
events_b = _drain_generator(
manager.send_message(sandbox_id, session_id_b, "msg b")
)
# Only ONE shared client was created
MockClass.assert_called_once()
# get_or_create_session called twice (once per craft session)
assert mock_client.get_or_create_session.call_count == 2
# send_message called with different ACP session IDs
mock_client.send_message.assert_any_call("msg a", session_id=acp_session_a)
mock_client.send_message.assert_any_call("msg b", session_id=acp_session_b)
# Both session mappings exist
assert manager._acp_session_ids[(sandbox_id, session_id_a)] == acp_session_a
assert manager._acp_session_ids[(sandbox_id, session_id_b)] == acp_session_b
assert mock_event_a in events_a
assert mock_event_b in events_b
def test_send_message_replaces_dead_client(manager: Any) -> None:
"""If the shared client has is_running == False, should replace it and
re-create sessions."""
sandbox_id: UUID = uuid4()
session_id: UUID = uuid4()
# Place a dead client in the cache
dead_client = _make_mock_client(is_running=False)
manager._acp_clients[sandbox_id] = dead_client
manager._acp_session_ids[(sandbox_id, session_id)] = "old-acp-session"
# Create the replacement client
new_event = _make_mock_event()
new_client = _make_mock_client(is_running=True)
new_acp_session = "new-acp-session"
new_client.get_or_create_session.return_value = new_acp_session
type(new_client).session_ids = property(lambda _: [new_acp_session])
new_client.send_message.return_value = iter([new_event])
with patch(_ACP_CLIENT_CLASS, return_value=new_client):
events = _drain_generator(manager.send_message(sandbox_id, session_id, "test"))
# Dead client was stopped during replacement
dead_client.stop.assert_called_once()
# New client was started
new_client.start.assert_called_once()
# Old session mapping was cleared, new one created
assert manager._acp_session_ids[(sandbox_id, session_id)] == new_acp_session
# Cache holds the new client
assert manager._acp_clients[sandbox_id] is new_client
assert new_event in events
# ---------------------------------------------------------------------------
# Tests: Cleanup
# ---------------------------------------------------------------------------
def test_terminate_stops_shared_client(manager: Any) -> None:
"""terminate(sandbox_id) should stop the shared client and clear
all session mappings for that sandbox."""
sandbox_id: UUID = uuid4()
session_id_1: UUID = uuid4()
session_id_2: UUID = uuid4()
mock_client = _make_mock_client(is_running=True)
manager._acp_clients[sandbox_id] = mock_client
manager._acp_session_ids[(sandbox_id, session_id_1)] = "acp-1"
manager._acp_session_ids[(sandbox_id, session_id_2)] = "acp-2"
with patch.object(manager, "_cleanup_kubernetes_resources"):
manager.terminate(sandbox_id)
# Shared client was stopped
mock_client.stop.assert_called_once()
# Client removed from cache
assert sandbox_id not in manager._acp_clients
# Session mappings removed
assert (sandbox_id, session_id_1) not in manager._acp_session_ids
assert (sandbox_id, session_id_2) not in manager._acp_session_ids
def test_terminate_leaves_other_sandbox_untouched(manager: Any) -> None:
"""terminate(sandbox_A) should NOT affect sandbox_B's client or sessions."""
sandbox_a: UUID = uuid4()
sandbox_b: UUID = uuid4()
session_a: UUID = uuid4()
session_b: UUID = uuid4()
client_a = _make_mock_client(is_running=True)
client_b = _make_mock_client(is_running=True)
manager._acp_clients[sandbox_a] = client_a
manager._acp_clients[sandbox_b] = client_b
manager._acp_session_ids[(sandbox_a, session_a)] = "acp-a"
manager._acp_session_ids[(sandbox_b, session_b)] = "acp-b"
with patch.object(manager, "_cleanup_kubernetes_resources"):
manager.terminate(sandbox_a)
# sandbox_a cleaned up
client_a.stop.assert_called_once()
assert sandbox_a not in manager._acp_clients
assert (sandbox_a, session_a) not in manager._acp_session_ids
# sandbox_b untouched
client_b.stop.assert_not_called()
assert sandbox_b in manager._acp_clients
assert manager._acp_session_ids[(sandbox_b, session_b)] == "acp-b"
def test_cleanup_session_removes_session_mapping(manager: Any) -> None:
"""cleanup_session_workspace() should remove the session mapping but
leave the shared client alive for other sessions."""
sandbox_id: UUID = uuid4()
session_id: UUID = uuid4()
mock_client = _make_mock_client(is_running=True)
manager._acp_clients[sandbox_id] = mock_client
manager._acp_session_ids[(sandbox_id, session_id)] = "acp-session-xyz"
with patch.object(manager, "_stream_core_api") as mock_stream_api:
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
manager.cleanup_session_workspace(sandbox_id, session_id)
# Session mapping removed
assert (sandbox_id, session_id) not in manager._acp_session_ids
# Shared client is NOT stopped (other sessions may use it)
mock_client.stop.assert_not_called()
assert sandbox_id in manager._acp_clients
def test_cleanup_session_handles_no_mapping(manager: Any) -> None:
"""cleanup_session_workspace() should not error when there's no
session mapping."""
sandbox_id: UUID = uuid4()
session_id: UUID = uuid4()
assert (sandbox_id, session_id) not in manager._acp_session_ids
with patch.object(manager, "_stream_core_api") as mock_stream_api:
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
manager.cleanup_session_workspace(sandbox_id, session_id)
assert (sandbox_id, session_id) not in manager._acp_session_ids


@@ -0,0 +1,93 @@
import pytest
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
class TestParseScimFilter:
"""Tests for SCIM filter expression parsing."""
def test_eq_filter_double_quoted(self) -> None:
result = parse_scim_filter('userName eq "john@example.com"')
assert result == ScimFilter(
attribute="userName",
operator=ScimFilterOperator.EQUAL,
value="john@example.com",
)
def test_eq_filter_single_quoted(self) -> None:
result = parse_scim_filter("userName eq 'john@example.com'")
assert result == ScimFilter(
attribute="userName",
operator=ScimFilterOperator.EQUAL,
value="john@example.com",
)
def test_co_filter(self) -> None:
result = parse_scim_filter('displayName co "Engineering"')
assert result == ScimFilter(
attribute="displayName",
operator=ScimFilterOperator.CONTAINS,
value="Engineering",
)
def test_sw_filter(self) -> None:
result = parse_scim_filter('userName sw "admin"')
assert result == ScimFilter(
attribute="userName",
operator=ScimFilterOperator.STARTS_WITH,
value="admin",
)
def test_case_insensitive_operator(self) -> None:
result = parse_scim_filter('userName EQ "test@example.com"')
assert result is not None
assert result.operator == ScimFilterOperator.EQUAL
def test_external_id_filter(self) -> None:
result = parse_scim_filter('externalId eq "abc-123"')
assert result == ScimFilter(
attribute="externalId",
operator=ScimFilterOperator.EQUAL,
value="abc-123",
)
def test_empty_value(self) -> None:
result = parse_scim_filter('userName eq ""')
assert result == ScimFilter(
attribute="userName",
operator=ScimFilterOperator.EQUAL,
value="",
)
def test_whitespace_trimming(self) -> None:
result = parse_scim_filter(' userName eq "test" ')
assert result is not None
assert result.value == "test"
@pytest.mark.parametrize(
"filter_string",
[
None,
"",
" ",
],
)
def test_empty_input_returns_none(self, filter_string: str | None) -> None:
assert parse_scim_filter(filter_string) is None
@pytest.mark.parametrize(
"filter_string",
[
"userName", # missing operator and value
"userName eq", # missing value
'userName gt "5"', # unsupported operator
'userName ne "test"', # unsupported operator
"userName eq unquoted", # unquoted value
'a eq "x" and b eq "y"', # compound filter not supported
],
)
def test_malformed_input_raises_value_error(self, filter_string: str) -> None:
with pytest.raises(ValueError, match="Unsupported or malformed"):
parse_scim_filter(filter_string)


@@ -0,0 +1,258 @@
import pytest
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
def _make_user(**kwargs: object) -> ScimUserResource:
defaults: dict = {
"userName": "test@example.com",
"active": True,
"name": ScimName(givenName="Test", familyName="User"),
}
defaults.update(kwargs)
return ScimUserResource(**defaults)
def _make_group(**kwargs: object) -> ScimGroupResource:
defaults: dict = {"displayName": "Engineering"}
defaults.update(kwargs)
return ScimGroupResource(**defaults)
def _replace_op(
path: str | None = None,
value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
return ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path=path, value=value)
def _add_op(
path: str | None = None,
value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
return ScimPatchOperation(op=ScimPatchOperationType.ADD, path=path, value=value)
def _remove_op(path: str) -> ScimPatchOperation:
return ScimPatchOperation(op=ScimPatchOperationType.REMOVE, path=path)
class TestApplyUserPatch:
"""Tests for SCIM user PATCH operations."""
def test_deactivate_user(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("active", False)], user)
assert result.active is False
assert result.userName == "test@example.com"
def test_activate_user(self) -> None:
user = _make_user(active=False)
result = apply_user_patch([_replace_op("active", True)], user)
assert result.active is True
def test_replace_given_name(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
assert result.name is not None
assert result.name.givenName == "NewFirst"
assert result.name.familyName == "User"
def test_replace_family_name(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
assert result.name is not None
assert result.name.familyName == "NewLast"
def test_replace_username(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
assert result.userName == "new@example.com"
def test_replace_without_path_uses_dict(self) -> None:
user = _make_user()
result = apply_user_patch(
[_replace_op(None, {"active": False, "userName": "new@example.com"})],
user,
)
assert result.active is False
assert result.userName == "new@example.com"
def test_multiple_operations(self) -> None:
user = _make_user()
result = apply_user_patch(
[
_replace_op("active", False),
_replace_op("name.givenName", "Updated"),
],
user,
)
assert result.active is False
assert result.name is not None
assert result.name.givenName == "Updated"
def test_case_insensitive_path(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("Active", False)], user)
assert result.active is False
def test_original_not_mutated(self) -> None:
user = _make_user()
apply_user_patch([_replace_op("active", False)], user)
assert user.active is True
def test_unsupported_path_raises(self) -> None:
user = _make_user()
with pytest.raises(ScimPatchError, match="Unsupported path"):
apply_user_patch([_replace_op("unknownField", "value")], user)
def test_remove_op_on_user_raises(self) -> None:
user = _make_user()
with pytest.raises(ScimPatchError, match="Unsupported operation"):
apply_user_patch([_remove_op("active")], user)
class TestApplyGroupPatch:
"""Tests for SCIM group PATCH operations."""
def test_replace_display_name(self) -> None:
group = _make_group()
result, added, removed = apply_group_patch(
[_replace_op("displayName", "New Name")], group
)
assert result.displayName == "New Name"
assert added == []
assert removed == []
def test_add_members(self) -> None:
group = _make_group()
result, added, removed = apply_group_patch(
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
group,
)
assert len(result.members) == 2
assert added == ["user-1", "user-2"]
assert removed == []
def test_add_members_without_path(self) -> None:
group = _make_group()
result, added, _ = apply_group_patch(
[_add_op(None, [{"value": "user-1"}])],
group,
)
assert len(result.members) == 1
assert added == ["user-1"]
def test_add_duplicate_member_skipped(self) -> None:
group = _make_group(members=[ScimGroupMember(value="user-1")])
result, added, _ = apply_group_patch(
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
group,
)
assert len(result.members) == 2
assert added == ["user-2"]
def test_remove_member(self) -> None:
group = _make_group(
members=[
ScimGroupMember(value="user-1"),
ScimGroupMember(value="user-2"),
]
)
result, added, removed = apply_group_patch(
[_remove_op('members[value eq "user-1"]')],
group,
)
assert len(result.members) == 1
assert result.members[0].value == "user-2"
assert removed == ["user-1"]
assert added == []
def test_remove_nonexistent_member(self) -> None:
group = _make_group(members=[ScimGroupMember(value="user-1")])
result, _, removed = apply_group_patch(
[_remove_op('members[value eq "user-999"]')],
group,
)
assert len(result.members) == 1
assert removed == []
def test_mixed_operations(self) -> None:
group = _make_group(members=[ScimGroupMember(value="user-1")])
result, added, removed = apply_group_patch(
[
_replace_op("displayName", "Renamed"),
_add_op("members", [{"value": "user-2"}]),
_remove_op('members[value eq "user-1"]'),
],
group,
)
assert result.displayName == "Renamed"
assert added == ["user-2"]
assert removed == ["user-1"]
assert len(result.members) == 1
def test_remove_without_path_raises(self) -> None:
group = _make_group()
with pytest.raises(ScimPatchError, match="requires a path"):
apply_group_patch(
[ScimPatchOperation(op=ScimPatchOperationType.REMOVE, path=None)],
group,
)
def test_remove_invalid_path_raises(self) -> None:
group = _make_group()
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
apply_group_patch([_remove_op("displayName")], group)
def test_replace_members_with_path(self) -> None:
group = _make_group(
members=[
ScimGroupMember(value="user-1"),
ScimGroupMember(value="user-2"),
]
)
result, added, removed = apply_group_patch(
[_replace_op("members", [{"value": "user-2"}, {"value": "user-3"}])],
group,
)
assert len(result.members) == 2
member_ids = {m.value for m in result.members}
assert member_ids == {"user-2", "user-3"}
assert "user-3" in added
assert "user-1" in removed
assert "user-2" not in added
assert "user-2" not in removed
def test_replace_members_empty_list_clears(self) -> None:
group = _make_group(
members=[
ScimGroupMember(value="user-1"),
ScimGroupMember(value="user-2"),
]
)
result, added, removed = apply_group_patch(
[_replace_op("members", [])],
group,
)
assert len(result.members) == 0
assert added == []
assert set(removed) == {"user-1", "user-2"}
def test_unsupported_replace_path_raises(self) -> None:
group = _make_group()
with pytest.raises(ScimPatchError, match="Unsupported path"):
apply_group_patch([_replace_op("unknownField", "val")], group)
def test_original_not_mutated(self) -> None:
group = _make_group()
apply_group_patch([_replace_op("displayName", "Changed")], group)
assert group.displayName == "Engineering"
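The member bookkeeping these group-patch tests describe can be sketched as below. This is a hypothetical reduction operating on plain member-id lists rather than the real Pydantic models; the function names `add_members`, `remove_member`, and `replace_members` are illustrative, not the actual helpers in `ee.onyx.server.scim.patch`.

```python
import re


def add_members(current: list[str], new_ids: list[str]) -> tuple[list[str], list[str]]:
    """Append new members, skipping duplicates; return (result, actually_added)."""
    result = list(current)  # copy: the original must not be mutated
    added: list[str] = []
    for member_id in new_ids:
        if member_id not in result:
            result.append(member_id)
            added.append(member_id)
    return result, added


# Only paths of the form members[value eq "<id>"] are removable.
_REMOVE_PATH_RE = re.compile(r"^members\[value eq [\"'](?P<id>[^\"']+)[\"']\]$")


def remove_member(current: list[str], path: str) -> tuple[list[str], list[str]]:
    """Remove the member addressed by the path; return (result, actually_removed)."""
    match = _REMOVE_PATH_RE.match(path)
    if match is None:
        raise ValueError(f"Unsupported remove path: {path}")
    target = match.group("id")
    removed = [m for m in current if m == target]  # empty when member absent
    return [m for m in current if m != target], removed


def replace_members(
    current: list[str], new_ids: list[str]
) -> tuple[list[str], list[str], list[str]]:
    """Replace the full member list; report the effective diff as (result, added, removed)."""
    added = [m for m in new_ids if m not in current]
    removed = [m for m in current if m not in new_ids]
    return list(new_ids), added, removed
```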


@@ -367,7 +367,6 @@ webserver:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -457,7 +456,6 @@ api:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -553,7 +551,6 @@ celery_worker_heavy:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -587,7 +584,6 @@ celery_worker_docprocessing:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -621,7 +617,6 @@ celery_worker_light:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -655,7 +650,6 @@ celery_worker_monitoring:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -689,7 +683,6 @@ celery_worker_primary:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -723,7 +716,6 @@ celery_worker_user_file_processing:
# KEDA specific configurations
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations
@@ -868,7 +860,6 @@ celery_worker_docfetching:
# KEDA specific configurations (only used when autoscaling.engine is set to 'keda')
pollingInterval: 30 # seconds
cooldownPeriod: 300 # seconds
idleReplicaCount: 1 # minimum replicas when idle
failureThreshold: 3 # number of failures before fallback
fallbackReplicas: 1 # replicas to maintain on failure
# Custom triggers for advanced KEDA configurations


@@ -196,7 +196,7 @@ members = ["backend", "tools/ods"]
[tool.basedpyright]
include = ["backend"]
-exclude = ["backend/generated"]
+exclude = ["backend/generated", "backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx", "backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/venv"]
typeCheckingMode = "off"
[tool.ruff]


@@ -70,6 +70,9 @@ ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POP
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
# Add NODE_OPTIONS argument
ARG NODE_OPTIONS
@@ -144,6 +147,9 @@ ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POP
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}


@@ -1,22 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgMaximize2 = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
-viewBox="0 0 14 14"
+viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
strokeWidth={2.5}
{...props}
>
<path
-d="M9 1H13M13 1V5M13 1L8.33333 5.66667M5 13H1M1 13V9M1 13L5.66667 8.33333"
+d="M10 2H14M14 2V6M14 2L9.33333 6.66667M6 14H2M2 14V10M2 14L6.66667 9.33333"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMaximize2;

web/package-lock.json generated

@@ -13748,9 +13748,9 @@
"license": "MIT"
},
"node_modules/qs": {
-"version": "6.14.1",
-"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",
-"integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==",
+"version": "6.14.2",
+"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
+"integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==",
"license": "BSD-3-Clause",
"dependencies": {
"side-channel": "^1.1.0"


@@ -1,5 +1,5 @@
import { Form, Formik } from "formik";
-import { PopupSpec } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import { SelectorFormField, TextFormField } from "@/components/Field";
import { createApiKey, updateApiKey } from "./lib";
import Modal from "@/refresh-components/Modal";
@@ -10,14 +10,12 @@ import { APIKey } from "./types";
import { SvgKey } from "@opal/icons";
export interface OnyxApiKeyFormProps {
onClose: () => void;
setPopup: (popupSpec: PopupSpec | null) => void;
onCreateApiKey: (apiKey: APIKey) => void;
apiKey?: APIKey;
}
export default function OnyxApiKeyForm({
onClose,
setPopup,
onCreateApiKey,
apiKey,
}: OnyxApiKeyFormProps) {
@@ -54,12 +52,11 @@ export default function OnyxApiKeyForm({
}
formikHelpers.setSubmitting(false);
if (response.ok) {
-setPopup({
-message: isUpdate
+toast.success(
+isUpdate
? "Successfully updated API key!"
-: "Successfully created API key!",
-type: "success",
-});
+: "Successfully created API key!"
+);
if (!isUpdate) {
onCreateApiKey(await response.json());
}
@@ -67,12 +64,11 @@ export default function OnyxApiKeyForm({
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
-setPopup({
-message: isUpdate
+toast.error(
+isUpdate
? `Error updating API key - ${errorMsg}`
-: `Error creating API key - ${errorMsg}`,
-type: "error",
-});
+: `Error creating API key - ${errorMsg}`
+);
}
}}
>


@@ -15,7 +15,7 @@ import {
Table,
} from "@/components/ui/table";
import Title from "@/components/ui/title";
-import { usePopup } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import { useState } from "react";
import { DeleteButton } from "@/components/DeleteButton";
import Modal from "@/refresh-components/Modal";
@@ -33,8 +33,6 @@ import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
function Main() {
const { popup, setPopup } = usePopup();
const {
data: apiKeys,
isLoading,
@@ -84,7 +82,6 @@ function Main() {
if (filteredApiKeys.length === 0) {
return (
<div>
{popup}
{introSection}
{showCreateUpdateForm && (
@@ -97,7 +94,6 @@ function Main() {
setSelectedApiKey(undefined);
mutate("/api/admin/api-key");
}}
setPopup={setPopup}
apiKey={selectedApiKey}
/>
)}
@@ -107,8 +103,6 @@ function Main() {
return (
<>
{popup}
<Modal open={!!fullApiKey}>
<Modal.Content width="sm" height="sm">
<Modal.Header
@@ -171,10 +165,7 @@ function Main() {
setKeyIsGenerating(false);
if (!response.ok) {
const errorMsg = await response.text();
-setPopup({
-type: "error",
-message: `Failed to regenerate API Key: ${errorMsg}`,
-});
+toast.error(`Failed to regenerate API Key: ${errorMsg}`);
return;
}
const newKey = (await response.json()) as APIKey;
@@ -191,10 +182,7 @@ function Main() {
const response = await deleteApiKey(apiKey.api_key_id);
if (!response.ok) {
const errorMsg = await response.text();
-setPopup({
-type: "error",
-message: `Failed to delete API Key: ${errorMsg}`,
-});
+toast.error(`Failed to delete API Key: ${errorMsg}`);
return;
}
mutate("/api/admin/api-key");
@@ -216,7 +204,6 @@ function Main() {
setSelectedApiKey(undefined);
mutate("/api/admin/api-key");
}}
setPopup={setPopup}
apiKey={selectedApiKey}
/>
)}


@@ -4,7 +4,7 @@ import Text from "@/refresh-components/texts/Text";
import { Persona } from "./interfaces";
import { useRouter } from "next/navigation";
import Checkbox from "@/refresh-components/inputs/Checkbox";
-import { usePopup } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import { useState, useMemo, useEffect } from "react";
import { UniqueIdentifier } from "@dnd-kit/core";
import { DraggableTable } from "@/components/table/DraggableTable";
@@ -56,7 +56,6 @@ export function PersonasTable({
pageSize: number;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
const { refreshUser, isAdmin } = useUser();
const editablePersonas = useMemo(() => {
@@ -109,10 +108,7 @@ export function PersonasTable({
});
if (!response.ok) {
-setPopup({
-type: "error",
-message: `Failed to update persona order - ${await response.text()}`,
-});
+toast.error(`Failed to update persona order - ${await response.text()}`);
setFinalPersonas(personas);
await refreshPersonas();
return;
@@ -139,10 +135,7 @@ export function PersonasTable({
refreshPersonas();
closeDeleteModal();
} else {
-setPopup({
-type: "error",
-message: `Failed to delete persona - ${await response.text()}`,
-});
+toast.error(`Failed to delete persona - ${await response.text()}`);
}
}
};
@@ -167,17 +160,13 @@ export function PersonasTable({
refreshPersonas();
closeDefaultModal();
} else {
-setPopup({
-type: "error",
-message: `Failed to update persona - ${await response.text()}`,
-});
+toast.error(`Failed to update persona - ${await response.text()}`);
}
}
};
return (
<div>
{popup}
{deleteModalOpen && personaToDelete && (
<ConfirmationModalLayout
icon={SvgAlertCircle}
@@ -290,10 +279,9 @@ export function PersonasTable({
if (response.ok) {
refreshPersonas();
} else {
-setPopup({
-type: "error",
-message: `Failed to update persona - ${await response.text()}`,
-});
+toast.error(
+`Failed to update persona - ${await response.text()}`
+);
}
}}
className={`


@@ -1,7 +1,6 @@
"use client";
import CardSection from "@/components/admin/CardSection";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { SlackTokensForm } from "./SlackTokensForm";
@@ -17,7 +16,6 @@ export const NewSlackBotForm = () => {
app_token: "",
user_token: "",
});
const { popup, setPopup } = usePopup();
const router = useRouter();
return (
@@ -27,12 +25,10 @@ export const NewSlackBotForm = () => {
title="New Slack Bot"
/>
<CardSection>
{popup}
<div className="p-4">
<SlackTokensForm
isUpdate={false}
initialValues={formValues}
setPopup={setPopup}
router={router}
/>
</div>


@@ -1,6 +1,6 @@
"use client";
-import { usePopup } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import { SlackBot, ValidSources } from "@/lib/types";
import { useRouter } from "next/navigation";
import { useState, useEffect, useRef } from "react";
@@ -24,7 +24,6 @@ export const ExistingSlackBotForm = ({
}) => {
const [isExpanded, setIsExpanded] = useState(false);
const [formValues, setFormValues] = useState(existingSlackBot);
const { popup, setPopup } = usePopup();
const router = useRouter();
const dropdownRef = useRef<HTMLDivElement>(null);
const [showDeleteModal, setShowDeleteModal] = useState(false);
@@ -42,15 +41,9 @@ export const ExistingSlackBotForm = ({
if (!response.ok) {
throw new Error(await response.text());
}
-setPopup({
-message: `Connector ${field} updated successfully`,
-type: "success",
-});
+toast.success(`Connector ${field} updated successfully`);
} catch (error) {
-setPopup({
-message: `Failed to update connector ${field}`,
-type: "error",
-});
+toast.error(`Failed to update connector ${field}`);
}
setFormValues((prev) => ({ ...prev, [field]: value }));
};
@@ -74,7 +67,6 @@ export const ExistingSlackBotForm = ({
return (
<div>
{popup}
<div className="flex items-center justify-between h-14">
<div className="flex items-center gap-2">
<div className="my-auto">
@@ -120,7 +112,6 @@ export const ExistingSlackBotForm = ({
initialValues={formValues}
existingSlackBotId={existingSlackBot.id}
refreshSlackBot={refreshSlackBot}
setPopup={setPopup}
router={router}
onValuesChange={(values) => setFormValues(values)}
/>
@@ -149,16 +140,10 @@ export const ExistingSlackBotForm = ({
if (!response.ok) {
throw new Error(await response.text());
}
-setPopup({
-message: "Slack bot deleted successfully",
-type: "success",
-});
+toast.success("Slack bot deleted successfully");
router.push("/admin/bots");
} catch (error) {
-setPopup({
-message: "Failed to delete Slack bot",
-type: "error",
-});
+toast.error("Failed to delete Slack bot");
}
setShowDeleteModal(false);
}}


@@ -8,13 +8,13 @@ import Button from "@/refresh-components/buttons/Button";
import Separator from "@/refresh-components/Separator";
import { useEffect } from "react";
import { DOCS_ADMINS_PATH } from "@/lib/constants";
import { toast } from "@/hooks/useToast";
export const SlackTokensForm = ({
isUpdate,
initialValues,
existingSlackBotId,
refreshSlackBot,
setPopup,
router,
onValuesChange,
}: {
@@ -22,7 +22,6 @@ export const SlackTokensForm = ({
initialValues: any;
existingSlackBotId?: number;
refreshSlackBot?: () => void;
setPopup: (popup: { message: string; type: "error" | "success" }) => void;
router: any;
onValuesChange?: (values: any) => void;
}) => {
@@ -59,12 +58,11 @@ export const SlackTokensForm = ({
}
const responseJson = await response.json();
const botId = isUpdate ? existingSlackBotId : responseJson.id;
-setPopup({
-message: isUpdate
+toast.success(
+isUpdate
? "Successfully updated Slack Bot!"
-: "Successfully created Slack Bot!",
-type: "success",
-});
+: "Successfully created Slack Bot!"
+);
router.push(`/admin/bots/${encodeURIComponent(botId)}`);
} else {
const responseJson = await response.json();
@@ -75,12 +73,11 @@ export const SlackTokensForm = ({
} else if (errorMsg.includes("Invalid app token:")) {
errorMsg = "Slack App Token is invalid";
}
-setPopup({
-message: isUpdate
+toast.error(
+isUpdate
? `Error updating Slack Bot - ${errorMsg}`
-: `Error creating Slack Bot - ${errorMsg}`,
-type: "error",
-});
+: `Error creating Slack Bot - ${errorMsg}`
+);
}
}}
enableReinitialize={true}


@@ -1,7 +1,7 @@
"use client";
import { PageSelector } from "@/components/PageSelector";
-import { PopupSpec } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import { EditIcon } from "@/components/icons/icons";
import { SlackChannelConfig } from "@/lib/types";
import {
@@ -27,14 +27,12 @@ export interface SlackChannelConfigsTableProps {
slackBotId: number;
slackChannelConfigs: SlackChannelConfig[];
refresh: () => void;
setPopup: (popupSpec: PopupSpec | null) => void;
}
export default function SlackChannelConfigsTable({
slackBotId,
slackChannelConfigs,
refresh,
setPopup,
}: SlackChannelConfigsTableProps) {
const [page, setPage] = useState(1);
@@ -130,16 +128,14 @@ export default function SlackChannelConfigsTable({
slackChannelConfig.id
);
if (response.ok) {
-setPopup({
-message: `Slack bot config "${slackChannelConfig.id}" deleted`,
-type: "success",
-});
+toast.success(
+`Slack bot config "${slackChannelConfig.id}" deleted`
+);
} else {
const errorMsg = await response.text();
-setPopup({
-message: `Failed to delete Slack bot config - ${errorMsg}`,
-type: "error",
-});
+toast.error(
+`Failed to delete Slack bot config - ${errorMsg}`
+);
}
refresh();
}}


@@ -3,7 +3,7 @@
import React, { useMemo } from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
-import { usePopup } from "@/components/admin/connectors/Popup";
+import { toast } from "@/hooks/useToast";
import {
DocumentSetSummary,
SlackChannelConfig,
@@ -34,7 +34,6 @@ export const SlackChannelConfigCreationForm = ({
standardAnswerCategoryResponse: StandardAnswerCategoryResponse;
existingSlackChannelConfig?: SlackChannelConfig;
}) => {
const { popup, setPopup } = usePopup();
const router = useRouter();
const isUpdate = Boolean(existingSlackChannelConfig);
const isDefault = existingSlackChannelConfig?.is_default || false;
@@ -65,8 +64,6 @@ export const SlackChannelConfigCreationForm = ({
return (
<CardSection className="!px-12 max-w-4xl">
{popup}
<Formik
initialValues={{
slack_bot_id: slack_bot_id,
@@ -221,12 +218,11 @@ export const SlackChannelConfigCreationForm = ({
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
-setPopup({
-message: `Error ${
+toast.error(
+`Error ${
isUpdate ? "updating" : "creating"
-} OnyxBot config - ${errorMsg}`,
-type: "error",
-});
+} OnyxBot config - ${errorMsg}`
+);
}
}}
>
@@ -241,7 +237,6 @@ export const SlackChannelConfigCreationForm = ({
searchEnabledAssistants={searchEnabledAssistants}
nonSearchAssistants={nonSearchAssistants}
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
setPopup={setPopup}
slack_bot_id={slack_bot_id}
formikProps={formikProps}
/>


@@ -3,6 +3,7 @@
import { useState, useEffect, useMemo } from "react";
import { FieldArray, useFormikContext, ErrorMessage } from "formik";
import { DocumentSetSummary } from "@/lib/types";
import { toast } from "@/hooks/useToast";
import {
Label,
SelectorFormField,
@@ -47,10 +48,6 @@ export interface SlackChannelConfigFormFieldsProps {
searchEnabledAssistants: MinimalPersonaSnapshot[];
nonSearchAssistants: MinimalPersonaSnapshot[];
standardAnswerCategoryResponse: StandardAnswerCategoryResponse;
setPopup: (popup: {
message: string;
type: "error" | "success" | "warning";
}) => void;
slack_bot_id: number;
formikProps: any;
}
@@ -62,7 +59,6 @@ export function SlackChannelConfigFormFields({
searchEnabledAssistants,
nonSearchAssistants,
standardAnswerCategoryResponse,
setPopup,
slack_bot_id,
formikProps,
}: SlackChannelConfigFormFieldsProps) {
@@ -142,13 +138,11 @@ export function SlackChannelConfigFormFields({
(dsId: number) => !invalidSelected.includes(dsId)
)
);
-setPopup({
-message:
-"We removed one or more document sets from your selection because they are no longer valid. Please review and update your configuration.",
-type: "warning",
-});
+toast.warning(
+"We removed one or more document sets from your selection because they are no longer valid. Please review and update your configuration."
+);
}
-}, [unselectableSets, values.document_sets, setFieldValue, setPopup]);
+}, [unselectableSets, values.document_sets, setFieldValue]);
const shouldShowPrivacyAlert = useMemo(() => {
if (values.knowledge_source === "document_sets") {

Some files were not shown because too many files have changed in this diff.